forked from phoenix/litellm-mirror

Commit 68cb5cae58: Merge branch 'main' into litellm_redis_cluster
56 changed files with 2079 additions and 411 deletions
@@ -282,7 +282,7 @@ jobs:
   pip install "pytest==7.3.1"
   pip install "pytest-asyncio==0.21.1"
   pip install aiohttp
-  pip install openai
+  pip install "openai==1.40.0"
   python -m pip install --upgrade pip
   pip install "pydantic==2.7.1"
   pip install "pytest==7.3.1"

@@ -13,8 +13,9 @@ spec:
       {{- include "litellm.selectorLabels" . | nindent 6 }}
   template:
     metadata:
-      {{- with .Values.podAnnotations }}
       annotations:
+        checksum/config: {{ include (print $.Template.BasePath "/configmap-litellm.yaml") . | sha256sum }}
+        {{- with .Values.podAnnotations }}
         {{- toYaml . | nindent 8 }}
       {{- end }}
       labels:

@@ -81,6 +81,7 @@ Works for:
 ```python
 import os
 from litellm import completion
+from pydantic import BaseModel

 # add to env var
 os.environ["OPENAI_API_KEY"] = ""

@@ -8,6 +8,7 @@ liteLLM supports:

 - [Custom Callback Functions](https://docs.litellm.ai/docs/observability/custom_callback)
 - [Langfuse](https://langfuse.com/docs)
+- [LangSmith](https://www.langchain.com/langsmith)
 - [Helicone](https://docs.helicone.ai/introduction)
 - [Traceloop](https://traceloop.com/docs)
 - [Lunary](https://lunary.ai/docs)

@@ -56,7 +56,7 @@ response = litellm.completion(
 ```

 ## Advanced
-### Set Langsmith fields - Custom Projec, Run names, tags
+### Set Langsmith fields

 ```python
 import litellm

@@ -77,7 +77,15 @@ response = litellm.completion(
     metadata={
         "run_name": "litellmRUN", # langsmith run name
         "project_name": "litellm-completion", # langsmith project name
-        "tags": ["model1", "prod-2"] # tags to log on langsmith
+        "run_id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", # langsmith run id
+        "parent_run_id": "f8faf8c1-9778-49a4-9004-628cdb0047e5", # langsmith run parent run id
+        "trace_id": "df570c03-5a03-4cea-8df0-c162d05127ac", # langsmith run trace id
+        "session_id": "1ffd059c-17ea-40a8-8aef-70fd0307db82", # langsmith run session id
+        "tags": ["model1", "prod-2"], # langsmith run tags
+        "metadata": { # langsmith run metadata
+            "key1": "value1"
+        },
+        "dotted_order": "20240429T004912090000Z497f6eca-6276-4993-bfeb-53cbbbba6f08"
     }
 )
 print(response)

@@ -1,6 +1,10 @@
+import Image from '@theme/IdealImage';
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
 # [BETA] Vertex AI Endpoints (Pass-Through)

-Pass-through endpoints for Vertex AI - call provider-specific endpoint, in native format (no translation).
+Use the Vertex AI SDK to call endpoints on the LiteLLM Gateway (native provider format).

 :::tip

|
@ -40,16 +44,119 @@ litellm --config /path/to/config.yaml
|
||||||
|
|
||||||
#### 3. Test it
|
#### 3. Test it
|
||||||
|
|
||||||
```shell
|
```python
|
||||||
curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:countTokens \
|
import vertexai
|
||||||
-H "Content-Type: application/json" \
|
from google.auth.credentials import Credentials
|
||||||
-H "Authorization: Bearer sk-1234" \
|
from vertexai.generative_models import GenerativeModel
|
||||||
-d '{"instances":[{"content": "gm"}]}'
|
|
||||||
|
LITELLM_PROXY_API_KEY = "sk-1234"
|
||||||
|
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialsWrapper(Credentials):
|
||||||
|
def __init__(self, token=None):
|
||||||
|
super().__init__()
|
||||||
|
self.token = token
|
||||||
|
self.expiry = None # or set to a future date if needed
|
||||||
|
|
||||||
|
def refresh(self, request):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def apply(self, headers, token=None):
|
||||||
|
headers["Authorization"] = f"Bearer {self.token}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expired(self):
|
||||||
|
return False # Always consider the token as non-expired
|
||||||
|
|
||||||
|
@property
|
||||||
|
def valid(self):
|
||||||
|
return True # Always consider the credentials as valid
|
||||||
|
|
||||||
|
|
||||||
|
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
|
||||||
|
|
||||||
|
vertexai.init(
|
||||||
|
project="adroit-crow-413218",
|
||||||
|
location="us-central1",
|
||||||
|
api_endpoint=LITELLM_PROXY_BASE,
|
||||||
|
credentials=credentials,
|
||||||
|
api_transport="rest",
|
||||||
|
)
|
||||||
|
|
||||||
|
model = GenerativeModel("gemini-1.5-flash-001")
|
||||||
|
|
||||||
|
response = model.generate_content(
|
||||||
|
"What's a good name for a flower shop that specializes in selling bouquets of dried flowers?"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response.text)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage Examples
|
## Usage Examples
|
||||||
|
|
||||||
### Gemini API (Generate Content)
|
### Gemini API (Generate Content)
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="py" label="Vertex Python SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
import vertexai
|
||||||
|
from google.auth.credentials import Credentials
|
||||||
|
from vertexai.generative_models import GenerativeModel
|
||||||
|
|
||||||
|
LITELLM_PROXY_API_KEY = "sk-1234"
|
||||||
|
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialsWrapper(Credentials):
|
||||||
|
def __init__(self, token=None):
|
||||||
|
super().__init__()
|
||||||
|
self.token = token
|
||||||
|
self.expiry = None # or set to a future date if needed
|
||||||
|
|
||||||
|
def refresh(self, request):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def apply(self, headers, token=None):
|
||||||
|
headers["Authorization"] = f"Bearer {self.token}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expired(self):
|
||||||
|
return False # Always consider the token as non-expired
|
||||||
|
|
||||||
|
@property
|
||||||
|
def valid(self):
|
||||||
|
return True # Always consider the credentials as valid
|
||||||
|
|
||||||
|
|
||||||
|
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
|
||||||
|
|
||||||
|
vertexai.init(
|
||||||
|
project="adroit-crow-413218",
|
||||||
|
location="us-central1",
|
||||||
|
api_endpoint=LITELLM_PROXY_BASE,
|
||||||
|
credentials=credentials,
|
||||||
|
api_transport="rest",
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
model = GenerativeModel("gemini-1.5-flash-001")
|
||||||
|
|
||||||
|
response = model.generate_content(
|
||||||
|
"What's a good name for a flower shop that specializes in selling bouquets of dried flowers?"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response.text)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="Curl" label="Curl">
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent \
|
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
|
@ -57,8 +164,77 @@ curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-0
|
||||||
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
|
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
### Embeddings API
|
### Embeddings API
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="py" label="Vertex Python SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import List, Optional
|
||||||
|
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
|
||||||
|
import vertexai
|
||||||
|
from google.auth.credentials import Credentials
|
||||||
|
from vertexai.generative_models import GenerativeModel
|
||||||
|
|
||||||
|
LITELLM_PROXY_API_KEY = "sk-1234"
|
||||||
|
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialsWrapper(Credentials):
|
||||||
|
def __init__(self, token=None):
|
||||||
|
super().__init__()
|
||||||
|
self.token = token
|
||||||
|
self.expiry = None # or set to a future date if needed
|
||||||
|
|
||||||
|
def refresh(self, request):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def apply(self, headers, token=None):
|
||||||
|
headers["Authorization"] = f"Bearer {self.token}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expired(self):
|
||||||
|
return False # Always consider the token as non-expired
|
||||||
|
|
||||||
|
@property
|
||||||
|
def valid(self):
|
||||||
|
return True # Always consider the credentials as valid
|
||||||
|
|
||||||
|
|
||||||
|
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
|
||||||
|
|
||||||
|
vertexai.init(
|
||||||
|
project="adroit-crow-413218",
|
||||||
|
location="us-central1",
|
||||||
|
api_endpoint=LITELLM_PROXY_BASE,
|
||||||
|
credentials=credentials,
|
||||||
|
api_transport="rest",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def embed_text(
|
||||||
|
texts: List[str] = ["banana muffins? ", "banana bread? banana muffins?"],
|
||||||
|
task: str = "RETRIEVAL_DOCUMENT",
|
||||||
|
model_name: str = "text-embedding-004",
|
||||||
|
dimensionality: Optional[int] = 256,
|
||||||
|
) -> List[List[float]]:
|
||||||
|
"""Embeds texts with a pre-trained, foundational model."""
|
||||||
|
model = TextEmbeddingModel.from_pretrained(model_name)
|
||||||
|
inputs = [TextEmbeddingInput(text, task) for text in texts]
|
||||||
|
kwargs = dict(output_dimensionality=dimensionality) if dimensionality else {}
|
||||||
|
embeddings = model.get_embeddings(inputs, **kwargs)
|
||||||
|
return [embedding.values for embedding in embeddings]
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="curl" label="Curl">
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:predict \
|
curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:predict \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
|
@ -66,8 +242,86 @@ curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-geck
|
||||||
-d '{"instances":[{"content": "gm"}]}'
|
-d '{"instances":[{"content": "gm"}]}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
### Imagen API
|
### Imagen API
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="py" label="Vertex Python SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import List, Optional
|
||||||
|
from vertexai.preview.vision_models import ImageGenerationModel
|
||||||
|
import vertexai
|
||||||
|
from google.auth.credentials import Credentials
|
||||||
|
|
||||||
|
LITELLM_PROXY_API_KEY = "sk-1234"
|
||||||
|
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialsWrapper(Credentials):
|
||||||
|
def __init__(self, token=None):
|
||||||
|
super().__init__()
|
||||||
|
self.token = token
|
||||||
|
self.expiry = None # or set to a future date if needed
|
||||||
|
|
||||||
|
def refresh(self, request):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def apply(self, headers, token=None):
|
||||||
|
headers["Authorization"] = f"Bearer {self.token}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expired(self):
|
||||||
|
return False # Always consider the token as non-expired
|
||||||
|
|
||||||
|
@property
|
||||||
|
def valid(self):
|
||||||
|
return True # Always consider the credentials as valid
|
||||||
|
|
||||||
|
|
||||||
|
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
|
||||||
|
|
||||||
|
vertexai.init(
|
||||||
|
project="adroit-crow-413218",
|
||||||
|
location="us-central1",
|
||||||
|
api_endpoint=LITELLM_PROXY_BASE,
|
||||||
|
credentials=credentials,
|
||||||
|
api_transport="rest",
|
||||||
|
)
|
||||||
|
|
||||||
|
model = ImageGenerationModel.from_pretrained("imagen-3.0-generate-001")
|
||||||
|
|
||||||
|
images = model.generate_images(
|
||||||
|
prompt=prompt,
|
||||||
|
# Optional parameters
|
||||||
|
number_of_images=1,
|
||||||
|
language="en",
|
||||||
|
# You can't use a seed value and watermark at the same time.
|
||||||
|
# add_watermark=False,
|
||||||
|
# seed=100,
|
||||||
|
aspect_ratio="1:1",
|
||||||
|
safety_filter_level="block_some",
|
||||||
|
person_generation="allow_adult",
|
||||||
|
)
|
||||||
|
|
||||||
|
images[0].save(location=output_file, include_generation_parameters=False)
|
||||||
|
|
||||||
|
# Optional. View the generated image in a notebook.
|
||||||
|
# images[0].show()
|
||||||
|
|
||||||
|
print(f"Created output image using {len(images[0]._image_bytes)} bytes")
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="curl" label="Curl">
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:4000/vertex-ai/publishers/google/models/imagen-3.0-generate-001:predict \
|
curl http://localhost:4000/vertex-ai/publishers/google/models/imagen-3.0-generate-001:predict \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
|
@ -75,8 +329,86 @@ curl http://localhost:4000/vertex-ai/publishers/google/models/imagen-3.0-generat
|
||||||
-d '{"instances":[{"prompt": "make an otter"}], "parameters": {"sampleCount": 1}}'
|
-d '{"instances":[{"prompt": "make an otter"}], "parameters": {"sampleCount": 1}}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
### Count Tokens API
|
### Count Tokens API
|
||||||
|
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
|
||||||
|
<TabItem value="py" label="Vertex Python SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import List, Optional
|
||||||
|
from vertexai.generative_models import GenerativeModel
|
||||||
|
import vertexai
|
||||||
|
from google.auth.credentials import Credentials
|
||||||
|
|
||||||
|
LITELLM_PROXY_API_KEY = "sk-1234"
|
||||||
|
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialsWrapper(Credentials):
|
||||||
|
def __init__(self, token=None):
|
||||||
|
super().__init__()
|
||||||
|
self.token = token
|
||||||
|
self.expiry = None # or set to a future date if needed
|
||||||
|
|
||||||
|
def refresh(self, request):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def apply(self, headers, token=None):
|
||||||
|
headers["Authorization"] = f"Bearer {self.token}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expired(self):
|
||||||
|
return False # Always consider the token as non-expired
|
||||||
|
|
||||||
|
@property
|
||||||
|
def valid(self):
|
||||||
|
return True # Always consider the credentials as valid
|
||||||
|
|
||||||
|
|
||||||
|
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
|
||||||
|
|
||||||
|
vertexai.init(
|
||||||
|
project="adroit-crow-413218",
|
||||||
|
location="us-central1",
|
||||||
|
api_endpoint=LITELLM_PROXY_BASE,
|
||||||
|
credentials=credentials,
|
||||||
|
api_transport="rest",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
model = GenerativeModel("gemini-1.5-flash-001")
|
||||||
|
|
||||||
|
prompt = "Why is the sky blue?"
|
||||||
|
|
||||||
|
# Prompt tokens count
|
||||||
|
response = model.count_tokens(prompt)
|
||||||
|
print(f"Prompt Token Count: {response.total_tokens}")
|
||||||
|
print(f"Prompt Character Count: {response.total_billable_characters}")
|
||||||
|
|
||||||
|
# Send text to Gemini
|
||||||
|
response = model.generate_content(prompt)
|
||||||
|
|
||||||
|
# Response tokens count
|
||||||
|
usage_metadata = response.usage_metadata
|
||||||
|
print(f"Prompt Token Count: {usage_metadata.prompt_token_count}")
|
||||||
|
print(f"Candidates Token Count: {usage_metadata.candidates_token_count}")
|
||||||
|
print(f"Total Token Count: {usage_metadata.total_token_count}")
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="curl" label="Curl">
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
|
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
|
@ -84,10 +416,83 @@ curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-0
|
||||||
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
|
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
### Tuning API
|
### Tuning API
|
||||||
|
|
||||||
Create Fine Tuning Job
|
Create Fine Tuning Job
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
|
||||||
|
<TabItem value="py" label="Vertex Python SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import List, Optional
|
||||||
|
from vertexai.preview.tuning import sft
|
||||||
|
import vertexai
|
||||||
|
from google.auth.credentials import Credentials
|
||||||
|
|
||||||
|
LITELLM_PROXY_API_KEY = "sk-1234"
|
||||||
|
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
|
||||||
|
|
||||||
|
import datetime
import time  # needed for the polling loop below (time.sleep)
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialsWrapper(Credentials):
|
||||||
|
def __init__(self, token=None):
|
||||||
|
super().__init__()
|
||||||
|
self.token = token
|
||||||
|
self.expiry = None # or set to a future date if needed
|
||||||
|
|
||||||
|
def refresh(self, request):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def apply(self, headers, token=None):
|
||||||
|
headers["Authorization"] = f"Bearer {self.token}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expired(self):
|
||||||
|
return False # Always consider the token as non-expired
|
||||||
|
|
||||||
|
@property
|
||||||
|
def valid(self):
|
||||||
|
return True # Always consider the credentials as valid
|
||||||
|
|
||||||
|
|
||||||
|
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
|
||||||
|
|
||||||
|
vertexai.init(
|
||||||
|
project="adroit-crow-413218",
|
||||||
|
location="us-central1",
|
||||||
|
api_endpoint=LITELLM_PROXY_BASE,
|
||||||
|
credentials=credentials,
|
||||||
|
api_transport="rest",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Note: a second vertexai.init(project=PROJECT_ID, location="us-central1") call is not
# needed here - PROJECT_ID is undefined and vertexai.init(...) above already targets the LiteLLM proxy.
|
||||||
|
|
||||||
|
sft_tuning_job = sft.train(
|
||||||
|
source_model="gemini-1.0-pro-002",
|
||||||
|
train_dataset="gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Polling for job completion
|
||||||
|
while not sft_tuning_job.has_ended:
|
||||||
|
time.sleep(60)
|
||||||
|
sft_tuning_job.refresh()
|
||||||
|
|
||||||
|
print(sft_tuning_job.tuned_model_name)
|
||||||
|
print(sft_tuning_job.tuned_model_endpoint_name)
|
||||||
|
print(sft_tuning_job.experiment)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="curl" label="Curl">
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:4000/vertex-ai/tuningJobs \
|
curl http://localhost:4000/vertex-ai/tuningJobs \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
|
@ -99,3 +504,7 @@ curl http://localhost:4000/vertex-ai/tuningJobs \
|
||||||
}
|
}
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
|
@ -131,6 +131,56 @@ Expected Response
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Add Streaming Support
|
||||||
|
|
||||||
|
Here's a simple example of returning unix epoch seconds for both completion + streaming use-cases.
|
||||||
|
|
||||||
|
s/o [@Eloy Lafuente](https://github.com/stronk7) for this code example.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import time
|
||||||
|
from typing import Iterator, AsyncIterator
|
||||||
|
from litellm.types.utils import GenericStreamingChunk, ModelResponse
|
||||||
|
from litellm import CustomLLM, completion, acompletion
|
||||||
|
|
||||||
|
class UnixTimeLLM(CustomLLM):
|
||||||
|
def completion(self, *args, **kwargs) -> ModelResponse:
|
||||||
|
return completion(
|
||||||
|
model="test/unixtime",
|
||||||
|
mock_response=str(int(time.time())),
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
async def acompletion(self, *args, **kwargs) -> ModelResponse:
|
||||||
|
return await acompletion(
|
||||||
|
model="test/unixtime",
|
||||||
|
mock_response=str(int(time.time())),
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
|
||||||
|
generic_streaming_chunk: GenericStreamingChunk = {
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"index": 0,
|
||||||
|
"is_finished": True,
|
||||||
|
"text": str(int(time.time())),
|
||||||
|
"tool_use": None,
|
||||||
|
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
|
||||||
|
}
|
||||||
|
return generic_streaming_chunk # type: ignore
|
||||||
|
|
||||||
|
async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
|
||||||
|
generic_streaming_chunk: GenericStreamingChunk = {
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"index": 0,
|
||||||
|
"is_finished": True,
|
||||||
|
"text": str(int(time.time())),
|
||||||
|
"tool_use": None,
|
||||||
|
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
|
||||||
|
}
|
||||||
|
yield generic_streaming_chunk # type: ignore
|
||||||
|
|
||||||
|
unixtime = UnixTimeLLM()
|
||||||
|
```
|
||||||
|
|
||||||
## Custom Handler Spec
|
## Custom Handler Spec
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|
|
@ -661,6 +661,7 @@ Here's how to use Vertex AI with the LiteLLM Proxy Server
|
||||||
## Specifying Safety Settings
|
## Specifying Safety Settings
|
||||||
In certain use-cases you may need to make calls to the models and pass [safety settings](https://ai.google.dev/docs/safety_setting_gemini) different from the defaults. To do so, simply pass the `safety_settings` argument to `completion` or `acompletion`. For example:
|
|
||||||
|
### Set per model/request
|
||||||
|
|
||||||
<Tabs>
|
<Tabs>
|
||||||
|
|
||||||
|
@ -752,6 +753,65 @@ response = client.chat.completions.create(
|
||||||
</TabItem>
|
</TabItem>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
|
### Set Globally
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
litellm.set_verbose = True  # 👈 See RAW REQUEST/RESPONSE
|
||||||
|
|
||||||
|
litellm.vertex_ai_safety_settings = [
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HARASSMENT",
|
||||||
|
"threshold": "BLOCK_NONE",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||||
|
"threshold": "BLOCK_NONE",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||||
|
"threshold": "BLOCK_NONE",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||||
|
"threshold": "BLOCK_NONE",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
response = litellm.completion(
|
||||||
|
model="vertex_ai/gemini-pro",
|
||||||
|
messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="Proxy">
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: gemini-experimental
|
||||||
|
litellm_params:
|
||||||
|
model: vertex_ai/gemini-experimental
|
||||||
|
vertex_project: litellm-epic
|
||||||
|
vertex_location: us-central1
|
||||||
|
|
||||||
|
litellm_settings:
|
||||||
|
vertex_ai_safety_settings:
|
||||||
|
- category: HARM_CATEGORY_HARASSMENT
|
||||||
|
threshold: BLOCK_NONE
|
||||||
|
- category: HARM_CATEGORY_HATE_SPEECH
|
||||||
|
threshold: BLOCK_NONE
|
||||||
|
- category: HARM_CATEGORY_SEXUALLY_EXPLICIT
|
||||||
|
threshold: BLOCK_NONE
|
||||||
|
- category: HARM_CATEGORY_DANGEROUS_CONTENT
|
||||||
|
threshold: BLOCK_NONE
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
## Set Vertex Project & Vertex Location
|
## Set Vertex Project & Vertex Location
|
||||||
All calls using Vertex AI require the following parameters:
|
All calls using Vertex AI require the following parameters:
|
||||||
* Your Project ID
|
* Your Project ID
|
||||||
|
@ -1450,7 +1510,7 @@ curl http://0.0.0.0:4000/v1/chat/completions \
|
||||||
| code-gecko@latest| `completion('code-gecko@latest', messages)` |
|
| code-gecko@latest| `completion('code-gecko@latest', messages)` |
|
||||||
|
|
||||||
|
|
||||||
## Embedding Models
|
## **Embedding Models**
|
||||||
|
|
||||||
#### Usage - Embedding
|
#### Usage - Embedding
|
||||||
```python
|
```python
|
||||||
|
@ -1504,7 +1564,158 @@ response = litellm.embedding(
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Image Generation Models
|
## **Multi-Modal Embeddings**
|
||||||
|
|
||||||
|
Usage
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
response = await litellm.aembedding(
|
||||||
|
model="vertex_ai/multimodalembedding@001",
|
||||||
|
input=[
|
||||||
|
{
|
||||||
|
"image": {
|
||||||
|
"gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
|
||||||
|
},
|
||||||
|
"text": "this is a unicorn",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="LiteLLM PROXY (Unified Endpoint)">
|
||||||
|
|
||||||
|
1. Add model to config.yaml
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: multimodalembedding@001
|
||||||
|
litellm_params:
|
||||||
|
model: vertex_ai/multimodalembedding@001
|
||||||
|
vertex_project: "adroit-crow-413218"
|
||||||
|
vertex_location: "us-central1"
|
||||||
|
vertex_credentials: adroit-crow-413218-a956eef1a2a8.json
|
||||||
|
|
||||||
|
litellm_settings:
|
||||||
|
drop_params: True
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Start Proxy
|
||||||
|
|
||||||
|
```
|
||||||
|
$ litellm --config /path/to/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Make a request using the OpenAI Python SDK
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
|
||||||
|
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
|
||||||
|
|
||||||
|
# request sent to model set on litellm proxy, `litellm --model`
|
||||||
|
response = client.embeddings.create(
|
||||||
|
model="multimodalembedding@001",
|
||||||
|
input = None,
|
||||||
|
extra_body = {
|
||||||
|
"instances": [
|
||||||
|
{
|
||||||
|
"image": {
|
||||||
|
"gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
|
||||||
|
},
|
||||||
|
"text": "this is a unicorn",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy-vtx" label="LiteLLM PROXY (Vertex SDK)">
|
||||||
|
|
||||||
|
1. Add model to config.yaml
|
||||||
|
```yaml
|
||||||
|
default_vertex_config:
|
||||||
|
vertex_project: "adroit-crow-413218"
|
||||||
|
vertex_location: "us-central1"
|
||||||
|
vertex_credentials: adroit-crow-413218-a956eef1a2a8.json
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Start Proxy
|
||||||
|
|
||||||
|
```
|
||||||
|
$ litellm --config /path/to/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Make a request using the Vertex AI SDK
|
||||||
|
|
||||||
|
```python
|
||||||
|
import vertexai
|
||||||
|
|
||||||
|
from vertexai.vision_models import Image, MultiModalEmbeddingModel, Video
|
||||||
|
from vertexai.vision_models import VideoSegmentConfig
|
||||||
|
from google.auth.credentials import Credentials
|
||||||
|
|
||||||
|
|
||||||
|
LITELLM_PROXY_API_KEY = "sk-1234"
|
||||||
|
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
class CredentialsWrapper(Credentials):
|
||||||
|
def __init__(self, token=None):
|
||||||
|
super().__init__()
|
||||||
|
self.token = token
|
||||||
|
self.expiry = None # or set to a future date if needed
|
||||||
|
|
||||||
|
def refresh(self, request):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def apply(self, headers, token=None):
|
||||||
|
headers['Authorization'] = f'Bearer {self.token}'
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expired(self):
|
||||||
|
return False # Always consider the token as non-expired
|
||||||
|
|
||||||
|
@property
|
||||||
|
def valid(self):
|
||||||
|
return True # Always consider the credentials as valid
|
||||||
|
|
||||||
|
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
|
||||||
|
|
||||||
|
vertexai.init(
|
||||||
|
project="adroit-crow-413218",
|
||||||
|
location="us-central1",
|
||||||
|
api_endpoint=LITELLM_PROXY_BASE,
|
||||||
|
credentials = credentials,
|
||||||
|
api_transport="rest",
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding")
|
||||||
|
image = Image.load_from_file(
|
||||||
|
"gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = model.get_embeddings(
|
||||||
|
image=image,
|
||||||
|
contextual_text="Colosseum",
|
||||||
|
dimension=1408,
|
||||||
|
)
|
||||||
|
print(f"Image Embedding: {embeddings.image_embedding}")
|
||||||
|
print(f"Text Embedding: {embeddings.text_embedding}")
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
|
## **Image Generation Models**
|
||||||
|
|
||||||
Usage
|
Usage
|
||||||
|
|
||||||
|
|
|
@@ -728,6 +728,7 @@ general_settings:
 "disable_spend_logs": "boolean", # turn off writing each transaction to the db
 "disable_master_key_return": "boolean", # turn off returning master key on UI (checked on '/user/info' endpoint)
 "disable_reset_budget": "boolean", # turn off reset budget scheduled task
+"disable_adding_master_key_hash_to_db": "boolean", # turn off storing master key hash in db, for spend tracking
 "enable_jwt_auth": "boolean", # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims
 "enforce_user_param": "boolean", # requires all openai endpoint requests to have a 'user' param
 "allowed_routes": "list", # list of allowed proxy API routes - a user can access. (currently JWT-Auth only)

|
|
|
@ -101,8 +101,38 @@ Requirements:
|
||||||
|
|
||||||
|
|
||||||
<Tabs>
|
<Tabs>
|
||||||
|
<TabItem value="key" label="Set on Key">
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -L -X POST 'http://0.0.0.0:4000/key/generate' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"metadata": {
|
||||||
|
"tags": ["tag1", "tag2", "tag3"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="team" label="Set on Team">
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -L -X POST 'http://0.0.0.0:4000/team/new' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"metadata": {
|
||||||
|
"tags": ["tag1", "tag2", "tag3"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||||
|
|
||||||
Set `extra_body={"metadata": { }}` to `metadata` you want to pass
|
Set `extra_body={"metadata": { }}` to `metadata` you want to pass
|
||||||
|
@ -270,7 +300,42 @@ Requirements:
|
||||||
|
|
||||||
|
|
||||||
<Tabs>
|
<Tabs>
|
||||||
|
<TabItem value="key" label="Set on Key">
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -L -X POST 'http://0.0.0.0:4000/key/generate' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"metadata": {
|
||||||
|
"spend_logs_metadata": {
|
||||||
|
"hello": "world"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="team" label="Set on Team">
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -L -X POST 'http://0.0.0.0:4000/team/new' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"metadata": {
|
||||||
|
"spend_logs_metadata": {
|
||||||
|
"hello": "world"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||||
|
|
||||||
|
|
|
@ -61,6 +61,51 @@ litellm_settings:
|
||||||
|
|
||||||
Removes any field with `user_api_key_*` from metadata.
|
Removes any field with `user_api_key_*` from metadata.
|
||||||
|
|
||||||
|
## What gets logged?
|
||||||
|
|
||||||
|
Found under `kwargs["standard_logging_payload"]`. This is a standard payload, logged for every response.
|
||||||
|
|
||||||
|
```python
|
||||||
|
class StandardLoggingPayload(TypedDict):
|
||||||
|
id: str
|
||||||
|
call_type: str
|
||||||
|
response_cost: float
|
||||||
|
total_tokens: int
|
||||||
|
prompt_tokens: int
|
||||||
|
completion_tokens: int
|
||||||
|
startTime: float
|
||||||
|
endTime: float
|
||||||
|
completionStartTime: float
|
||||||
|
model_map_information: StandardLoggingModelInformation
|
||||||
|
model: str
|
||||||
|
model_id: Optional[str]
|
||||||
|
model_group: Optional[str]
|
||||||
|
api_base: str
|
||||||
|
metadata: StandardLoggingMetadata
|
||||||
|
cache_hit: Optional[bool]
|
||||||
|
cache_key: Optional[str]
|
||||||
|
saved_cache_cost: Optional[float]
|
||||||
|
request_tags: list
|
||||||
|
end_user: Optional[str]
|
||||||
|
requester_ip_address: Optional[str]
|
||||||
|
messages: Optional[Union[str, list, dict]]
|
||||||
|
response: Optional[Union[str, list, dict]]
|
||||||
|
model_parameters: dict
|
||||||
|
hidden_params: StandardLoggingHiddenParams
|
||||||
|
|
||||||
|
class StandardLoggingHiddenParams(TypedDict):
|
||||||
|
model_id: Optional[str]
|
||||||
|
cache_key: Optional[str]
|
||||||
|
api_base: Optional[str]
|
||||||
|
response_cost: Optional[str]
|
||||||
|
additional_headers: Optional[dict]
|
||||||
|
|
||||||
|
|
||||||
|
class StandardLoggingModelInformation(TypedDict):
|
||||||
|
model_map_key: str
|
||||||
|
model_map_value: Optional[ModelInfo]
|
||||||
|
```
|
||||||
|
|
||||||
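
A minimal sketch of reading this payload from a custom success callback (this assumes the function-style `litellm.success_callback` interface; field access mirrors the TypedDict above, and `mock_response` is used only to avoid a real API call):

```python
import litellm
from litellm import completion

def log_standard_payload(kwargs, completion_response, start_time, end_time):
    # the standard payload is attached to kwargs for every response
    payload = kwargs.get("standard_logging_payload") or {}
    print(payload.get("id"), payload.get("model"), payload.get("response_cost"))

litellm.success_callback = [log_standard_payload]

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    mock_response="hello",  # no real provider call needed for this sketch
)
```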
## Logging Proxy Input/Output - Langfuse
|
## Logging Proxy Input/Output - Langfuse
|
||||||
|
|
||||||
We will use the `--config` to set `litellm.success_callback = ["langfuse"]`. This will log all successful LLM calls to Langfuse. Make sure to set `LANGFUSE_PUBLIC_KEY` and `LANGFUSE_SECRET_KEY` in your environment.
|
|
|
@ -334,3 +334,4 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||||
```
|
```
|
||||||
Key=... over available RPM=0. Model RPM=100, Active keys=None
|
Key=... over available RPM=0. Model RPM=100, Active keys=None
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -488,9 +488,34 @@ You can set:
|
||||||
|
|
||||||
|
|
||||||
<Tabs>
|
<Tabs>
|
||||||
|
<TabItem value="per-team" label="Per Team">
|
||||||
|
|
||||||
|
Use `/team/new` or `/team/update` to persist rate limits across multiple keys for a team.
|
||||||
|
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://0.0.0.0:4000/team/new' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{"team_id": "my-prod-team", "max_parallel_requests": 10, "tpm_limit": 20, "rpm_limit": 4}'
|
||||||
|
```
|
||||||
|
|
||||||
|
[**See Swagger**](https://litellm-api.up.railway.app/#/team%20management/new_team_team_new_post)
|
||||||
|
|
||||||
|
**Expected Response**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"key": "sk-sA7VDkyhlQ7m8Gt77Mbt3Q",
|
||||||
|
"expires": "2024-01-19T01:21:12.816168",
|
||||||
|
"team_id": "my-prod-team",
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
<TabItem value="per-user" label="Per Internal User">
|
<TabItem value="per-user" label="Per Internal User">
|
||||||
|
|
||||||
Use `/user/new`, to persist rate limits across multiple keys.
|
Use `/user/new` or `/user/update` to persist rate limits across multiple keys for internal users.
|
||||||
|
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
|
@ -653,6 +678,70 @@ curl --location 'http://localhost:4000/chat/completions' \
|
||||||
</TabItem>
|
</TabItem>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
|
## Set default budget for ALL internal users
|
||||||
|
|
||||||
|
Use this to set a default budget for users who you give keys to.
|
||||||
|
|
||||||
|
This will apply when a user has [`user_role="internal_user"`](./self_serve.md#available-roles) (set this via `/user/new` or `/user/update`).
|
||||||
|
|
||||||
|
This will NOT apply if a key has a team_id (team budgets will apply then). [Tell us how we can improve this!](https://github.com/BerriAI/litellm/issues)
|
||||||
|
|
||||||
|
1. Define max budget in your config.yaml
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: "gpt-3.5-turbo"
|
||||||
|
litellm_params:
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
api_key: os.environ/OPENAI_API_KEY
|
||||||
|
|
||||||
|
litellm_settings:
|
||||||
|
max_internal_user_budget: 0 # amount in USD
|
||||||
|
internal_user_budget_duration: "1mo" # reset every month
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Create key for user
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -L -X POST 'http://0.0.0.0:4000/key/generate' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{}'
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected Response:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
{
|
||||||
|
...
|
||||||
|
"key": "sk-X53RdxnDhzamRwjKXR4IHg"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Test it!
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -L -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-H 'Authorization: Bearer sk-X53RdxnDhzamRwjKXR4IHg' \
|
||||||
|
-d '{
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [{"role": "user", "content": "Hey, how's it going?"}]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected Response:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
{
|
||||||
|
"error": {
|
||||||
|
"message": "ExceededBudget: User=<user_id> over budget. Spend=3.7e-05, Budget=0.0",
|
||||||
|
"type": "budget_exceeded",
|
||||||
|
"param": null,
|
||||||
|
"code": "400"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
## Grant Access to new model
|
## Grant Access to new model
|
||||||
|
|
||||||
Use model access groups to give users access to select models, and add new ones to it over time (e.g. mistral, llama-2, etc.).
|
Use model access groups to give users access to select models, and add new ones to it over time (e.g. mistral, llama-2, etc.).
|
||||||
|
|
litellm-js/spend-logs/package-lock.json (generated, 8 changes)
|
@ -6,7 +6,7 @@
|
||||||
"": {
|
"": {
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@hono/node-server": "^1.10.1",
|
"@hono/node-server": "^1.10.1",
|
||||||
"hono": "^4.2.7"
|
"hono": "^4.5.8"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@types/node": "^20.11.17",
|
"@types/node": "^20.11.17",
|
||||||
|
@ -463,9 +463,9 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/hono": {
|
"node_modules/hono": {
|
||||||
"version": "4.2.7",
|
"version": "4.5.8",
|
||||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.2.7.tgz",
|
"resolved": "https://registry.npmjs.org/hono/-/hono-4.5.8.tgz",
|
||||||
"integrity": "sha512-k1xHi86tJnRIVvqhFMBDGFKJ8r5O+bEsT4P59ZK59r0F300Xd910/r237inVfuT/VmE86RQQffX4OYNda6dLXw==",
|
"integrity": "sha512-pqpSlcdqGkpTTRpLYU1PnCz52gVr0zVR9H5GzMyJWuKQLLEBQxh96q45QizJ2PPX8NATtz2mu31/PKW/Jt+90Q==",
|
||||||
"engines": {
|
"engines": {
|
||||||
"node": ">=16.0.0"
|
"node": ">=16.0.0"
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@hono/node-server": "^1.10.1",
|
"@hono/node-server": "^1.10.1",
|
||||||
"hono": "^4.2.7"
|
"hono": "^4.5.8"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@types/node": "^20.11.17",
|
"@types/node": "^20.11.17",
|
||||||
|
|
|
@ -339,6 +339,7 @@ api_version = None
|
||||||
organization = None
|
organization = None
|
||||||
project = None
|
project = None
|
||||||
config_path = None
|
config_path = None
|
||||||
|
vertex_ai_safety_settings: Optional[dict] = None
|
||||||
####### COMPLETION MODELS ###################
|
####### COMPLETION MODELS ###################
|
||||||
open_ai_chat_completion_models: List = []
|
open_ai_chat_completion_models: List = []
|
||||||
open_ai_text_completion_models: List = []
|
open_ai_text_completion_models: List = []
|
||||||
|
|
|
@ -98,6 +98,10 @@ class LangsmithLogger(CustomLogger):
|
||||||
project_name = metadata.get("project_name", self.langsmith_project)
|
project_name = metadata.get("project_name", self.langsmith_project)
|
||||||
run_name = metadata.get("run_name", self.langsmith_default_run_name)
|
run_name = metadata.get("run_name", self.langsmith_default_run_name)
|
||||||
run_id = metadata.get("id", None)
|
run_id = metadata.get("id", None)
|
||||||
|
parent_run_id = metadata.get("parent_run_id", None)
|
||||||
|
trace_id = metadata.get("trace_id", None)
|
||||||
|
session_id = metadata.get("session_id", None)
|
||||||
|
dotted_order = metadata.get("dotted_order", None)
|
||||||
tags = metadata.get("tags", []) or []
|
tags = metadata.get("tags", []) or []
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
|
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
|
||||||
|
@ -149,6 +153,18 @@ class LangsmithLogger(CustomLogger):
|
||||||
if run_id:
|
if run_id:
|
||||||
data["id"] = run_id
|
data["id"] = run_id
|
||||||
|
|
||||||
|
if parent_run_id:
|
||||||
|
data["parent_run_id"] = parent_run_id
|
||||||
|
|
||||||
|
if trace_id:
|
||||||
|
data["trace_id"] = trace_id
|
||||||
|
|
||||||
|
if session_id:
|
||||||
|
data["session_id"] = session_id
|
||||||
|
|
||||||
|
if dotted_order:
|
||||||
|
data["dotted_order"] = dotted_order
|
||||||
|
|
||||||
verbose_logger.debug("Langsmith Logging data on langsmith: %s", data)
|
verbose_logger.debug("Langsmith Logging data on langsmith: %s", data)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
|
@ -524,6 +524,7 @@ class Logging:
|
||||||
TextCompletionResponse,
|
TextCompletionResponse,
|
||||||
HttpxBinaryResponseContent,
|
HttpxBinaryResponseContent,
|
||||||
],
|
],
|
||||||
|
cache_hit: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Calculate response cost using result + logging object variables.
|
Calculate response cost using result + logging object variables.
|
||||||
|
@ -535,10 +536,13 @@ class Logging:
|
||||||
litellm_params=self.litellm_params
|
litellm_params=self.litellm_params
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cache_hit is None:
|
||||||
|
cache_hit = self.model_call_details.get("cache_hit", False)
|
||||||
|
|
||||||
response_cost = litellm.response_cost_calculator(
|
response_cost = litellm.response_cost_calculator(
|
||||||
response_object=result,
|
response_object=result,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
cache_hit=self.model_call_details.get("cache_hit", False),
|
cache_hit=cache_hit,
|
||||||
custom_llm_provider=self.model_call_details.get(
|
custom_llm_provider=self.model_call_details.get(
|
||||||
"custom_llm_provider", None
|
"custom_llm_provider", None
|
||||||
),
|
),
|
||||||
|
@ -630,6 +634,7 @@ class Logging:
|
||||||
init_response_obj=result,
|
init_response_obj=result,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
end_time=end_time,
|
end_time=end_time,
|
||||||
|
logging_obj=self,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return start_time, end_time, result
|
return start_time, end_time, result
|
||||||
|
@ -2181,6 +2186,7 @@ def get_standard_logging_object_payload(
|
||||||
init_response_obj: Any,
|
init_response_obj: Any,
|
||||||
start_time: dt_object,
|
start_time: dt_object,
|
||||||
end_time: dt_object,
|
end_time: dt_object,
|
||||||
|
logging_obj: Logging,
|
||||||
) -> Optional[StandardLoggingPayload]:
|
) -> Optional[StandardLoggingPayload]:
|
||||||
try:
|
try:
|
||||||
if kwargs is None:
|
if kwargs is None:
|
||||||
|
@ -2277,11 +2283,17 @@ def get_standard_logging_object_payload(
|
||||||
cache_key = litellm.cache.get_cache_key(**kwargs)
|
cache_key = litellm.cache.get_cache_key(**kwargs)
|
||||||
else:
|
else:
|
||||||
cache_key = None
|
cache_key = None
|
||||||
|
|
||||||
|
saved_cache_cost: Optional[float] = None
|
||||||
if cache_hit is True:
|
if cache_hit is True:
|
||||||
import time
|
import time
|
||||||
|
|
||||||
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
|
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
|
||||||
|
|
||||||
|
saved_cache_cost = logging_obj._response_cost_calculator(
|
||||||
|
result=init_response_obj, cache_hit=False
|
||||||
|
)
|
||||||
|
|
||||||
## Get model cost information ##
|
## Get model cost information ##
|
||||||
base_model = _get_base_model_from_metadata(model_call_details=kwargs)
|
base_model = _get_base_model_from_metadata(model_call_details=kwargs)
|
||||||
custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params)
|
custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params)
|
||||||
|
@ -2318,6 +2330,7 @@ def get_standard_logging_object_payload(
|
||||||
id=str(id),
|
id=str(id),
|
||||||
call_type=call_type or "",
|
call_type=call_type or "",
|
||||||
cache_hit=cache_hit,
|
cache_hit=cache_hit,
|
||||||
|
saved_cache_cost=saved_cache_cost,
|
||||||
startTime=start_time_float,
|
startTime=start_time_float,
|
||||||
endTime=end_time_float,
|
endTime=end_time_float,
|
||||||
completionStartTime=completion_start_time_float,
|
completionStartTime=completion_start_time_float,
|
||||||
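
The `saved_cache_cost` field added above is the cost the call would have incurred without the cache: on a cache hit, the logging object re-prices the response with `cache_hit=False`. A rough sketch of the idea, with simplified names (not the actual litellm internals):

```python
from typing import Callable, Optional

def compute_saved_cache_cost(
    response_cost_calculator: Callable[..., Optional[float]],
    result,
    cache_hit: Optional[bool],
) -> Optional[float]:
    """Estimate how much a cache hit saved by pricing the response as a normal call."""
    if cache_hit is not True:
        return None
    # price the response as if it had been served by the provider
    return response_cost_calculator(result=result, cache_hit=False)
```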
|
|
|
@ -84,6 +84,7 @@ BEDROCK_CONVERSE_MODELS = [
|
||||||
"meta.llama3-1-8b-instruct-v1:0",
|
"meta.llama3-1-8b-instruct-v1:0",
|
||||||
"meta.llama3-1-70b-instruct-v1:0",
|
"meta.llama3-1-70b-instruct-v1:0",
|
||||||
"meta.llama3-1-405b-instruct-v1:0",
|
"meta.llama3-1-405b-instruct-v1:0",
|
||||||
|
"meta.llama3-70b-instruct-v1:0",
|
||||||
"mistral.mistral-large-2407-v1:0",
|
"mistral.mistral-large-2407-v1:0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -1480,7 +1481,7 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
acompletion: bool,
|
acompletion: bool,
|
||||||
timeout: Optional[Union[float, httpx.Timeout]],
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
litellm_params=None,
|
litellm_params: dict,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
extra_headers: Optional[dict] = None,
|
extra_headers: Optional[dict] = None,
|
||||||
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||||
|
@ -1596,6 +1597,14 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
supported_tool_call_params = ["tools", "tool_choice"]
|
supported_tool_call_params = ["tools", "tool_choice"]
|
||||||
supported_guardrail_params = ["guardrailConfig"]
|
supported_guardrail_params = ["guardrailConfig"]
|
||||||
## TRANSFORMATION ##
|
## TRANSFORMATION ##
|
||||||
|
|
||||||
|
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
llm_provider="bedrock_converse",
|
||||||
|
user_continue_message=litellm_params.pop("user_continue_message", None),
|
||||||
|
)
|
||||||
|
|
||||||
# send all model-specific params in 'additional_request_params'
|
# send all model-specific params in 'additional_request_params'
|
||||||
for k, v in inference_params.items():
|
for k, v in inference_params.items():
|
||||||
if (
|
if (
|
||||||
|
@ -1608,11 +1617,6 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
for key in additional_request_keys:
|
for key in additional_request_keys:
|
||||||
inference_params.pop(key, None)
|
inference_params.pop(key, None)
|
||||||
|
|
||||||
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
|
|
||||||
messages=messages,
|
|
||||||
model=model,
|
|
||||||
llm_provider="bedrock_converse",
|
|
||||||
)
|
|
||||||
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
|
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
|
||||||
inference_params.pop("tools", [])
|
inference_params.pop("tools", [])
|
||||||
)
|
)
|
||||||
|
|
|
@ -124,12 +124,14 @@ class CohereConfig:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def validate_environment(api_key):
|
def validate_environment(api_key, headers: dict):
|
||||||
headers = {
|
headers.update(
|
||||||
|
{
|
||||||
"Request-Source": "unspecified:litellm",
|
"Request-Source": "unspecified:litellm",
|
||||||
"accept": "application/json",
|
"accept": "application/json",
|
||||||
"content-type": "application/json",
|
"content-type": "application/json",
|
||||||
}
|
}
|
||||||
|
)
|
||||||
if api_key:
|
if api_key:
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
return headers
|
return headers
|
||||||
|
@ -144,11 +146,12 @@ def completion(
|
||||||
encoding,
|
encoding,
|
||||||
api_key,
|
api_key,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
|
headers: dict,
|
||||||
optional_params=None,
|
optional_params=None,
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
):
|
):
|
||||||
headers = validate_environment(api_key)
|
headers = validate_environment(api_key, headers=headers)
|
||||||
completion_url = api_base
|
completion_url = api_base
|
||||||
model = model
|
model = model
|
||||||
prompt = " ".join(message["content"] for message in messages)
|
prompt = " ".join(message["content"] for message in messages)
|
||||||
|
@ -338,13 +341,14 @@ def embedding(
|
||||||
model_response: litellm.EmbeddingResponse,
|
model_response: litellm.EmbeddingResponse,
|
||||||
logging_obj: LiteLLMLoggingObj,
|
logging_obj: LiteLLMLoggingObj,
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
headers: dict,
|
||||||
encoding: Any,
|
encoding: Any,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
aembedding: Optional[bool] = None,
|
aembedding: Optional[bool] = None,
|
||||||
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
|
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
|
||||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||||
):
|
):
|
||||||
headers = validate_environment(api_key)
|
headers = validate_environment(api_key, headers=headers)
|
||||||
embed_url = "https://api.cohere.ai/v1/embed"
|
embed_url = "https://api.cohere.ai/v1/embed"
|
||||||
model = model
|
model = model
|
||||||
data = {"model": model, "texts": input, **optional_params}
|
data = {"model": model, "texts": input, **optional_params}
|
||||||
|
|
|
@ -116,12 +116,14 @@ class CohereChatConfig:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def validate_environment(api_key):
|
def validate_environment(api_key, headers: dict):
|
||||||
headers = {
|
headers.update(
|
||||||
|
{
|
||||||
"Request-Source": "unspecified:litellm",
|
"Request-Source": "unspecified:litellm",
|
||||||
"accept": "application/json",
|
"accept": "application/json",
|
||||||
"content-type": "application/json",
|
"content-type": "application/json",
|
||||||
}
|
}
|
||||||
|
)
|
||||||
if api_key:
|
if api_key:
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
return headers
|
return headers
|
||||||
|
@ -203,13 +205,14 @@ def completion(
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
print_verbose: Callable,
|
print_verbose: Callable,
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
headers: dict,
|
||||||
encoding,
|
encoding,
|
||||||
api_key,
|
api_key,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
):
|
):
|
||||||
headers = validate_environment(api_key)
|
headers = validate_environment(api_key, headers=headers)
|
||||||
completion_url = api_base
|
completion_url = api_base
|
||||||
model = model
|
model = model
|
||||||
most_recent_message, chat_history = cohere_messages_pt_v2(
|
most_recent_message, chat_history = cohere_messages_pt_v2(
|
||||||
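
With `validate_environment` now merging caller-supplied headers instead of overwriting them, extra headers can reach the Cohere request. A hedged sketch of passing them from the SDK (the header name is illustrative, and this assumes `extra_headers` is forwarded to the provider handler):

```python
import litellm

response = litellm.completion(
    model="cohere/command-r",
    messages=[{"role": "user", "content": "Hello!"}],
    extra_headers={"X-Client-Name": "my-app"},  # merged with the default Cohere headers
)
print(response.choices[0].message.content)
```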
|
|
|
@@ -4,14 +4,17 @@ import traceback
 import types
 import uuid
 from itertools import chain
-from typing import Optional
+from typing import List, Optional

 import aiohttp
 import httpx
 import requests
+from pydantic import BaseModel

 import litellm
 from litellm import verbose_logger
+from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
+from litellm.types.llms.openai import ChatCompletionAssistantToolCall


 class OllamaError(Exception):

@@ -175,7 +178,7 @@ class OllamaChatConfig:
                 ## CHECK IF MODEL SUPPORTS TOOL CALLING ##
                 try:
                     model_info = litellm.get_model_info(
-                        model=model, custom_llm_provider="ollama_chat"
+                        model=model, custom_llm_provider="ollama"
                     )
                     if model_info.get("supports_function_calling") is True:
                         optional_params["tools"] = value

@@ -237,13 +240,30 @@ def get_ollama_response(
     function_name = optional_params.pop("function_name", None)
     tools = optional_params.pop("tools", None)

+    new_messages = []
     for m in messages:
-        if "role" in m and m["role"] == "tool":
-            m["role"] = "assistant"
+        if isinstance(
+            m, BaseModel
+        ):  # avoid message serialization issues - https://github.com/BerriAI/litellm/issues/5319
+            m = m.model_dump(exclude_none=True)
+        if m.get("tool_calls") is not None and isinstance(m["tool_calls"], list):
+            new_tools: List[OllamaToolCall] = []
+            for tool in m["tool_calls"]:
+                typed_tool = ChatCompletionAssistantToolCall(**tool)  # type: ignore
+                if typed_tool["type"] == "function":
+                    ollama_tool_call = OllamaToolCall(
+                        function=OllamaToolCallFunction(
+                            name=typed_tool["function"]["name"],
+                            arguments=json.loads(typed_tool["function"]["arguments"]),
+                        )
+                    )
+                    new_tools.append(ollama_tool_call)
+            m["tool_calls"] = new_tools
+        new_messages.append(m)

     data = {
         "model": model,
-        "messages": messages,
+        "messages": new_messages,
         "options": optional_params,
         "stream": stream,
     }

@@ -263,7 +283,7 @@ def get_ollama_response(
         },
     )
     if acompletion is True:
-        if stream == True:
+        if stream is True:
             response = ollama_async_streaming(
                 url=url,
                 api_key=api_key,

@@ -283,7 +303,7 @@ def get_ollama_response(
                 function_name=function_name,
             )
             return response
-    elif stream == True:
+    elif stream is True:
         return ollama_completion_stream(
             url=url, api_key=api_key, data=data, logging_obj=logging_obj
         )
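For orientation, a rough standalone sketch of the message conversion this hunk introduces, using plain dicts instead of litellm's OllamaToolCall types (that simplification is an assumption for illustration): OpenAI-style assistant tool calls carry `arguments` as a JSON string, while Ollama expects a parsed object.

import json

def to_ollama_tool_calls(message: dict) -> dict:
    """Convert an OpenAI-style assistant message into an Ollama-style tool-call shape."""
    converted = []
    for tool in message.get("tool_calls") or []:
        if tool.get("type") == "function":
            converted.append(
                {
                    "function": {
                        "name": tool["function"]["name"],
                        # OpenAI sends arguments as a JSON string; Ollama wants a dict
                        "arguments": json.loads(tool["function"]["arguments"]),
                    }
                }
            )
    message["tool_calls"] = converted
    return message

# example
msg = {
    "role": "assistant",
    "tool_calls": [
        {
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
        }
    ],
}
print(to_ollama_tool_calls(msg))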
@@ -84,6 +84,8 @@ class MistralConfig:

    - `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'.

+    - `stop` (string or array of strings): Stop generation if this token is detected. Or if one of these tokens is detected when providing an array
+
    - `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results.

    - `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'.

@@ -99,6 +101,7 @@ class MistralConfig:
     random_seed: Optional[int] = None
     safe_prompt: Optional[bool] = None
     response_format: Optional[dict] = None
+    stop: Optional[Union[str, list]] = None

     def __init__(
         self,

@@ -110,6 +113,7 @@ class MistralConfig:
         random_seed: Optional[int] = None,
         safe_prompt: Optional[bool] = None,
         response_format: Optional[dict] = None,
+        stop: Optional[Union[str, list]] = None
     ) -> None:
         locals_ = locals().copy()
         for key, value in locals_.items():

@@ -143,6 +147,7 @@ class MistralConfig:
            "tools",
            "tool_choice",
            "seed",
+           "stop",
            "response_format",
        ]

@@ -166,6 +171,8 @@ class MistralConfig:
                 optional_params["temperature"] = value
             if param == "top_p":
                 optional_params["top_p"] = value
+            if param == "stop":
+                optional_params["stop"] = value
             if param == "tool_choice" and isinstance(value, str):
                 optional_params["tool_choice"] = self._map_tool_choice(
                     tool_choice=value
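A minimal sketch of exercising the newly mapped `stop` parameter through `litellm.completion`; the model name and the assumption of a valid `MISTRAL_API_KEY` in the environment are illustrative, not taken from this diff.

import os
import litellm

# assumes a valid Mistral API key is configured
os.environ.setdefault("MISTRAL_API_KEY", "sk-...")

response = litellm.completion(
    model="mistral/mistral-small-latest",
    messages=[{"role": "user", "content": "Count from 1 to 10, one number per line."}],
    stop=["5"],  # now forwarded to the Mistral API via MistralConfig.map_openai_params
)
print(response.choices[0].message.content)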
@@ -38,6 +38,18 @@ def prompt_injection_detection_default_pt():

 BAD_MESSAGE_ERROR_STR = "Invalid Message "

+# used to interweave user messages, to ensure user/assistant alternating
+DEFAULT_USER_CONTINUE_MESSAGE = {
+    "role": "user",
+    "content": "Please continue.",
+}  # similar to autogen. Only used if `litellm.modify_params=True`.
+
+# used to interweave assistant messages, to ensure user/assistant alternating
+DEFAULT_ASSISTANT_CONTINUE_MESSAGE = {
+    "role": "assistant",
+    "content": "Please continue.",
+}  # similar to autogen. Only used if `litellm.modify_params=True`.
+

 def map_system_message_pt(messages: list) -> list:
     """

@@ -2254,6 +2266,7 @@ def _bedrock_converse_messages_pt(
     messages: List,
     model: str,
     llm_provider: str,
+    user_continue_message: Optional[dict] = None,
 ) -> List[BedrockMessageBlock]:
     """
     Converts given messages from OpenAI format to Bedrock format

@@ -2264,6 +2277,21 @@ def _bedrock_converse_messages_pt(

     contents: List[BedrockMessageBlock] = []
     msg_i = 0
+
+    # if initial message is assistant message
+    if messages[0].get("role") is not None and messages[0]["role"] == "assistant":
+        if user_continue_message is not None:
+            messages.insert(0, user_continue_message)
+        elif litellm.modify_params:
+            messages.insert(0, DEFAULT_USER_CONTINUE_MESSAGE)
+
+    # if final message is assistant message
+    if messages[-1].get("role") is not None and messages[-1]["role"] == "assistant":
+        if user_continue_message is not None:
+            messages.append(user_continue_message)
+        elif litellm.modify_params:
+            messages.append(DEFAULT_USER_CONTINUE_MESSAGE)
+
     while msg_i < len(messages):
         user_content: List[BedrockContentBlock] = []
         init_msg_i = msg_i

@@ -2344,6 +2372,7 @@ def _bedrock_converse_messages_pt(
                 model=model,
                 llm_provider=llm_provider,
             )

     return contents

+
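A small sketch of the behaviour the continue-message logic enables, assuming valid AWS Bedrock credentials are configured (the model ID below is an example, not taken from this diff): with `litellm.modify_params = True`, an assistant-first conversation gets a synthetic "Please continue." user turn so Bedrock's user/assistant alternation requirement is met.

import litellm

litellm.modify_params = True  # opt in to automatic user/assistant interleaving

# Bedrock Converse expects the conversation to start with a user turn; with
# modify_params=True, litellm now prepends DEFAULT_USER_CONTINUE_MESSAGE for us.
response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[
        {"role": "assistant", "content": "Here is the summary so far ..."},
        {"role": "user", "content": "Continue the summary."},
    ],
)
print(response.choices[0].message.content)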
@@ -9,7 +9,7 @@ import types
 import uuid
 from enum import Enum
 from functools import partial
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Union

 import httpx  # type: ignore
 import requests  # type: ignore

@@ -38,12 +38,15 @@ from litellm.types.llms.vertex_ai import (
     FunctionDeclaration,
     GenerateContentResponseBody,
     GenerationConfig,
+    Instance,
+    InstanceVideo,
     PartType,
     RequestBody,
     SafetSettingsConfig,
     SystemInstructions,
     ToolConfig,
     Tools,
+    VertexMultimodalEmbeddingRequest,
 )
 from litellm.types.utils import GenericStreamingChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage

@@ -188,9 +191,11 @@ class GoogleAIStudioGeminiConfig:  # key diff from VertexAI - 'frequency_penalty
             elif value["type"] == "text":  # type: ignore
                 optional_params["response_mime_type"] = "text/plain"
             if "response_schema" in value:  # type: ignore
+                optional_params["response_mime_type"] = "application/json"
                 optional_params["response_schema"] = value["response_schema"]  # type: ignore
             elif value["type"] == "json_schema":  # type: ignore
                 if "json_schema" in value and "schema" in value["json_schema"]:  # type: ignore
+                    optional_params["response_mime_type"] = "application/json"
                     optional_params["response_schema"] = value["json_schema"]["schema"]  # type: ignore
         if param == "tools" and isinstance(value, list):
             gtool_func_declarations = []

@@ -400,9 +405,11 @@ class VertexGeminiConfig:
             elif value["type"] == "text":
                 optional_params["response_mime_type"] = "text/plain"
             if "response_schema" in value:
+                optional_params["response_mime_type"] = "application/json"
                 optional_params["response_schema"] = value["response_schema"]
             elif value["type"] == "json_schema":  # type: ignore
                 if "json_schema" in value and "schema" in value["json_schema"]:  # type: ignore
+                    optional_params["response_mime_type"] = "application/json"
                     optional_params["response_schema"] = value["json_schema"]["schema"]  # type: ignore
         if param == "frequency_penalty":
             optional_params["frequency_penalty"] = value

@@ -594,6 +601,10 @@ class VertexLLM(BaseLLM):
         self._credentials: Optional[Any] = None
         self.project_id: Optional[str] = None
         self.async_handler: Optional[AsyncHTTPHandler] = None
+        self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
+            "multimodalembedding",
+            "multimodalembedding@001",
+        ]

     def _process_response(
         self,

@@ -1537,6 +1548,160 @@ class VertexLLM(BaseLLM):

         return model_response

+    def multimodal_embedding(
+        self,
+        model: str,
+        input: Union[list, str],
+        print_verbose,
+        model_response: litellm.EmbeddingResponse,
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        logging_obj=None,
+        encoding=None,
+        vertex_project=None,
+        vertex_location=None,
+        vertex_credentials=None,
+        aembedding=False,
+        timeout=300,
+        client=None,
+    ):
+
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
+        else:
+            sync_handler = client  # type: ignore
+
+        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
+
+        auth_header, _ = self._ensure_access_token(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+        optional_params = optional_params or {}
+
+        request_data = VertexMultimodalEmbeddingRequest()
+
+        if "instances" in optional_params:
+            request_data["instances"] = optional_params["instances"]
+        elif isinstance(input, list):
+            request_data["instances"] = input
+        else:
+            # construct instances
+            vertex_request_instance = Instance(**optional_params)
+
+            if isinstance(input, str):
+                vertex_request_instance["text"] = input
+
+            request_data["instances"] = [vertex_request_instance]
+
+        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
+        logging_obj.pre_call(
+            input=[],
+            api_key=None,
+            additional_args={
+                "complete_input_dict": optional_params,
+                "request_str": request_str,
+            },
+        )
+
+        logging_obj.pre_call(
+            input=[],
+            api_key=None,
+            additional_args={
+                "complete_input_dict": optional_params,
+                "request_str": request_str,
+            },
+        )
+
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+            "Authorization": f"Bearer {auth_header}",
+        }
+
+        if aembedding is True:
+            return self.async_multimodal_embedding(
+                model=model,
+                api_base=url,
+                data=request_data,
+                timeout=timeout,
+                headers=headers,
+                client=client,
+                model_response=model_response,
+            )
+
+        response = sync_handler.post(
+            url=url,
+            headers=headers,
+            data=json.dumps(request_data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        if "predictions" not in _json_response:
+            raise litellm.InternalServerError(
+                message=f"embedding response does not contain 'predictions', got {_json_response}",
+                llm_provider="vertex_ai",
+                model=model,
+            )
+        _predictions = _json_response["predictions"]
+
+        model_response.data = _predictions
+        model_response.model = model
+
+        return model_response
+
+    async def async_multimodal_embedding(
+        self,
+        model: str,
+        api_base: str,
+        data: VertexMultimodalEmbeddingRequest,
+        model_response: litellm.EmbeddingResponse,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> litellm.EmbeddingResponse:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = timeout
+            client = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            client = client  # type: ignore
+
+        try:
+            response = await client.post(api_base, headers=headers, json=data)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+        _json_response = response.json()
+        if "predictions" not in _json_response:
+            raise litellm.InternalServerError(
+                message=f"embedding response does not contain 'predictions', got {_json_response}",
+                llm_provider="vertex_ai",
+                model=model,
+            )
+        _predictions = _json_response["predictions"]
+
+        model_response.data = _predictions
+        model_response.model = model
+
+        return model_response
+

 class ModelResponseIterator:
     def __init__(self, streaming_response, sync_stream: bool):
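For orientation, a rough sketch of the JSON body the new `multimodal_embedding` method ends up POSTing to the Vertex `:predict` endpoint; the project, location, and image URI below are placeholder values, not taken from this diff.

import json

# Placeholder project/location/model - substitute real values.
vertex_project = "my-project"
vertex_location = "us-central1"
model = "multimodalembedding@001"

url = (
    f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/"
    f"{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
)

# One instance may combine text with an image (and optionally video) reference.
request_data = {
    "instances": [
        {
            "image": {"gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"},
            "text": "this is a unicorn",
        }
    ]
}

print(url)
print(json.dumps(request_data, indent=2))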
@@ -943,6 +943,7 @@ def completion(
             output_cost_per_token=output_cost_per_token,
             cooldown_time=cooldown_time,
             text_completion=kwargs.get("text_completion"),
+            user_continue_message=kwargs.get("user_continue_message"),
         )
         logging.update_environment_variables(
             model=model,

@@ -1634,6 +1635,13 @@ def completion(
                 or "https://api.cohere.ai/v1/generate"
             )

+            headers = headers or litellm.headers or {}
+            if headers is None:
+                headers = {}
+
+            if extra_headers is not None:
+                headers.update(extra_headers)
+
             model_response = cohere.completion(
                 model=model,
                 messages=messages,

@@ -1644,6 +1652,7 @@ def completion(
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 encoding=encoding,
+                headers=headers,
                 api_key=cohere_key,
                 logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
             )

@@ -1674,6 +1683,13 @@ def completion(
                 or "https://api.cohere.ai/v1/chat"
             )

+            headers = headers or litellm.headers or {}
+            if headers is None:
+                headers = {}
+
+            if extra_headers is not None:
+                headers.update(extra_headers)
+
             model_response = cohere_chat.completion(
                 model=model,
                 messages=messages,

@@ -1682,6 +1698,7 @@ def completion(
                 print_verbose=print_verbose,
                 optional_params=optional_params,
                 litellm_params=litellm_params,
+                headers=headers,
                 logger_fn=logger_fn,
                 encoding=encoding,
                 api_key=cohere_key,
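A minimal sketch of what the new header plumbing enables from the SDK side; the header name/value and the assumption of a valid `COHERE_API_KEY` are illustrative only.

import os
import litellm

# assumes a valid Cohere API key is configured
os.environ.setdefault("COHERE_API_KEY", "...")

response = litellm.completion(
    model="command-r",  # routed to the Cohere chat endpoint
    messages=[{"role": "user", "content": "Say hello in one word."}],
    # forwarded to api.cohere.ai via the updated validate_environment(); header name is arbitrary
    extra_headers={"X-Client-Name": "my-app"},
)
print(response.choices[0].message.content)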
@@ -2288,7 +2305,7 @@ def completion(
                 model_response=model_response,
                 print_verbose=print_verbose,
                 optional_params=optional_params,
-                litellm_params=litellm_params,
+                litellm_params=litellm_params,  # type: ignore
                 logger_fn=logger_fn,
                 encoding=encoding,
                 logging_obj=logging,

@@ -2464,7 +2481,7 @@ def completion(
                 model_response=model_response,
                 encoding=encoding,
             )
-            if acompletion is True or optional_params.get("stream", False) == True:
+            if acompletion is True or optional_params.get("stream", False) is True:
                 return generator

             response = generator

@@ -3158,6 +3175,7 @@ def embedding(
     encoding_format = kwargs.get("encoding_format", None)
     proxy_server_request = kwargs.get("proxy_server_request", None)
     aembedding = kwargs.get("aembedding", None)
+    extra_headers = kwargs.get("extra_headers", None)
     ### CUSTOM MODEL COST ###
     input_cost_per_token = kwargs.get("input_cost_per_token", None)
     output_cost_per_token = kwargs.get("output_cost_per_token", None)

@@ -3229,6 +3247,7 @@ def embedding(
         "model_config",
         "cooldown_time",
         "tags",
+        "extra_headers",
     ]
     default_params = openai_params + litellm_params
     non_default_params = {

@@ -3292,7 +3311,7 @@ def embedding(
             "cooldown_time": cooldown_time,
         },
     )
-    if azure == True or custom_llm_provider == "azure":
+    if azure is True or custom_llm_provider == "azure":
         # azure configs
         api_type = get_secret("AZURE_API_TYPE") or "azure"

@@ -3398,12 +3417,18 @@ def embedding(
             or get_secret("CO_API_KEY")
             or litellm.api_key
         )
+
+        if extra_headers is not None and isinstance(extra_headers, dict):
+            headers = extra_headers
+        else:
+            headers = {}
         response = cohere.embedding(
             model=model,
             input=input,
             optional_params=optional_params,
             encoding=encoding,
             api_key=cohere_key,  # type: ignore
+            headers=headers,
             logging_obj=logging,
             model_response=EmbeddingResponse(),
             aembedding=aembedding,

@@ -3477,6 +3502,26 @@ def embedding(
             or get_secret("VERTEX_CREDENTIALS")
         )

+        if (
+            "image" in optional_params
+            or "video" in optional_params
+            or model in vertex_chat_completion.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
+        ):
+            # multimodal embedding is supported on vertex httpx
+            response = vertex_chat_completion.multimodal_embedding(
+                model=model,
+                input=input,
+                encoding=encoding,
+                logging_obj=logging,
+                optional_params=optional_params,
+                model_response=EmbeddingResponse(),
+                vertex_project=vertex_ai_project,
+                vertex_location=vertex_ai_location,
+                vertex_credentials=vertex_credentials,
+                aembedding=aembedding,
+                print_verbose=print_verbose,
+            )
+        else:
             response = vertex_ai.embedding(
                 model=model,
                 input=input,
@@ -3,9 +3,11 @@ model_list:
     litellm_params:
       model: "*"

 litellm_settings:
-  cache: True
-  cache_params:
-    type: redis
-    redis_startup_nodes: [{"host": "127.0.0.1", "port": "7001"}]
+  success_callback: ["s3"]
+  cache: true
+  s3_callback_params:
+    s3_bucket_name: mytestbucketlitellm # AWS Bucket Name for S3
+    s3_region_name: us-west-2 # AWS Region Name for S3
+    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
+    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
@@ -21,6 +21,13 @@ else:
     Span = Any


+class LiteLLMTeamRoles(enum.Enum):
+    # team admin
+    TEAM_ADMIN = "admin"
+    # team member
+    TEAM_MEMBER = "user"
+
+
 class LitellmUserRoles(str, enum.Enum):
     """
     Admin Roles:

@@ -335,6 +342,11 @@ class LiteLLMRoutes(enum.Enum):
         + sso_only_routes
     )

+    self_managed_routes: List = [
+        "/team/member_add",
+        "/team/member_delete",
+    ]  # routes that manage their own allowed/disallowed logic
+

 # class LiteLLMAllowedRoutes(LiteLLMBase):
 #     """

@@ -1308,6 +1320,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
     soft_budget: Optional[float] = None
     team_model_aliases: Optional[Dict] = None
     team_member_spend: Optional[float] = None
+    team_member: Optional[Member] = None
     team_metadata: Optional[Dict] = None

     # End User Params

@@ -975,8 +975,6 @@ async def user_api_key_auth(
             if not _is_user_proxy_admin(user_obj=user_obj):  # if non-admin
                 if is_llm_api_route(route=route):
                     pass
-                elif is_llm_api_route(route=request["route"].name):
-                    pass
                 elif (
                     route in LiteLLMRoutes.info_routes.value
                 ):  # check if user allowed to call an info route

@@ -1046,11 +1044,16 @@ async def user_api_key_auth(
                         status_code=status.HTTP_403_FORBIDDEN,
                         detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
                     )

                 elif (
                     _user_role == LitellmUserRoles.INTERNAL_USER.value
                     and route in LiteLLMRoutes.internal_user_routes.value
                 ):
                     pass
+                elif (
+                    route in LiteLLMRoutes.self_managed_routes.value
+                ):  # routes that manage their own allowed/disallowed logic
+                    pass
             else:
                 user_role = "unknown"
                 user_id = "unknown"

@@ -120,6 +120,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         max_parallel_requests = user_api_key_dict.max_parallel_requests
         if max_parallel_requests is None:
             max_parallel_requests = sys.maxsize
+        if data is None:
+            data = {}
         global_max_parallel_requests = data.get("metadata", {}).get(
             "global_max_parallel_requests", None
         )

@@ -119,6 +119,7 @@ async def new_user(
             http_request=Request(
                 scope={"type": "http", "path": "/user/new"},
             ),
+            user_api_key_dict=user_api_key_dict,
         )

     if data.send_invite_email is True:

@@ -732,7 +733,7 @@ async def delete_user(
     delete user and associated user keys

    ```
-    curl --location 'http://0.0.0.0:8000/team/delete' \
+    curl --location 'http://0.0.0.0:8000/user/delete' \

    --header 'Authorization: Bearer sk-1234' \

@@ -849,7 +849,7 @@ async def generate_key_helper_fn(
         }

     if (
-        litellm.get_secret("DISABLE_KEY_NAME", False) == True
+        litellm.get_secret("DISABLE_KEY_NAME", False) is True
     ):  # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much)
         pass
     else:
|
||||||
UpdateTeamRequest,
|
UpdateTeamRequest,
|
||||||
UserAPIKeyAuth,
|
UserAPIKeyAuth,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin, user_api_key_auth
|
||||||
from litellm.proxy.management_helpers.utils import (
|
from litellm.proxy.management_helpers.utils import (
|
||||||
add_new_member,
|
add_new_member,
|
||||||
management_endpoint_wrapper,
|
management_endpoint_wrapper,
|
||||||
|
@ -39,6 +39,16 @@ from litellm.proxy.management_helpers.utils import (
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_user_team_admin(
|
||||||
|
user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
|
||||||
|
) -> bool:
|
||||||
|
for member in team_obj.members_with_roles:
|
||||||
|
if member.user_id is not None and member.user_id == user_api_key_dict.user_id:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
#### TEAM MANAGEMENT ####
|
#### TEAM MANAGEMENT ####
|
||||||
@router.post(
|
@router.post(
|
||||||
"/team/new",
|
"/team/new",
|
||||||
|
@ -417,6 +427,7 @@ async def team_member_add(
|
||||||
|
|
||||||
If user doesn't exist, new user row will also be added to User Table
|
If user doesn't exist, new user row will also be added to User Table
|
||||||
|
|
||||||
|
Only proxy_admin or admin of team, allowed to access this endpoint.
|
||||||
```
|
```
|
||||||
|
|
||||||
curl -X POST 'http://0.0.0.0:4000/team/member_add' \
|
curl -X POST 'http://0.0.0.0:4000/team/member_add' \
|
||||||
|
@ -465,6 +476,25 @@ async def team_member_add(
|
||||||
|
|
||||||
complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump())
|
complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump())
|
||||||
|
|
||||||
|
## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN
|
||||||
|
|
||||||
|
if (
|
||||||
|
hasattr(user_api_key_dict, "user_role")
|
||||||
|
and user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
|
||||||
|
and not _is_user_team_admin(
|
||||||
|
user_api_key_dict=user_api_key_dict, team_obj=complete_team_data
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail={
|
||||||
|
"error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
|
||||||
|
"/team/member_add",
|
||||||
|
complete_team_data.team_id,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(data.member, Member):
|
if isinstance(data.member, Member):
|
||||||
# add to team db
|
# add to team db
|
||||||
new_member = data.member
|
new_member = data.member
|
||||||
|
@ -569,6 +599,23 @@ async def team_member_delete(
|
||||||
)
|
)
|
||||||
existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump())
|
existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump())
|
||||||
|
|
||||||
|
## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN
|
||||||
|
|
||||||
|
if (
|
||||||
|
user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
|
||||||
|
and not _is_user_team_admin(
|
||||||
|
user_api_key_dict=user_api_key_dict, team_obj=existing_team_row
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail={
|
||||||
|
"error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
|
||||||
|
"/team/member_delete", existing_team_row.team_id
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
## DELETE MEMBER FROM TEAM
|
## DELETE MEMBER FROM TEAM
|
||||||
new_team_members: List[Member] = []
|
new_team_members: List[Member] = []
|
||||||
for m in existing_team_row.members_with_roles:
|
for m in existing_team_row.members_with_roles:
|
||||||
|
|
|
@ -266,7 +266,7 @@ def management_endpoint_wrapper(func):
|
||||||
)
|
)
|
||||||
|
|
||||||
_http_request: Request = kwargs.get("http_request")
|
_http_request: Request = kwargs.get("http_request")
|
||||||
parent_otel_span = user_api_key_dict.parent_otel_span
|
parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None)
|
||||||
if parent_otel_span is not None:
|
if parent_otel_span is not None:
|
||||||
from litellm.proxy.proxy_server import open_telemetry_logger
|
from litellm.proxy.proxy_server import open_telemetry_logger
|
||||||
|
|
||||||
|
@ -310,7 +310,7 @@ def management_endpoint_wrapper(func):
|
||||||
user_api_key_dict: UserAPIKeyAuth = (
|
user_api_key_dict: UserAPIKeyAuth = (
|
||||||
kwargs.get("user_api_key_dict") or UserAPIKeyAuth()
|
kwargs.get("user_api_key_dict") or UserAPIKeyAuth()
|
||||||
)
|
)
|
||||||
parent_otel_span = user_api_key_dict.parent_otel_span
|
parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None)
|
||||||
if parent_otel_span is not None:
|
if parent_otel_span is not None:
|
||||||
from litellm.proxy.proxy_server import open_telemetry_logger
|
from litellm.proxy.proxy_server import open_telemetry_logger
|
||||||
|
|
||||||
|
|
|
@@ -301,16 +301,19 @@ async def pass_through_request(
             request=request, headers=headers, forward_headers=forward_headers
         )

+        _parsed_body = None
         if custom_body:
             _parsed_body = custom_body
         else:
             request_body = await request.body()
+            if request_body == b"" or request_body is None:
+                _parsed_body = None
+            else:
                 body_str = request_body.decode()
                 try:
                     _parsed_body = ast.literal_eval(body_str)
                 except Exception:
                     _parsed_body = json.loads(body_str)

         verbose_proxy_logger.debug(
             "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
                 url, headers, _parsed_body

@@ -320,7 +323,7 @@ async def pass_through_request(
         ### CALL HOOKS ### - modify incoming data / reject request before calling the model
         _parsed_body = await proxy_logging_obj.pre_call_hook(
             user_api_key_dict=user_api_key_dict,
-            data=_parsed_body,
+            data=_parsed_body or {},
             call_type="pass_through_endpoint",
         )

@@ -360,11 +363,20 @@ async def pass_through_request(

         # combine url with query params for logging

-        requested_query_params = query_params or request.query_params.__dict__
+        requested_query_params: Optional[dict] = (
+            query_params or request.query_params.__dict__
+        )
+        if requested_query_params == request.query_params.__dict__:
+            requested_query_params = None
+
+        requested_query_params_str = None
+        if requested_query_params:
             requested_query_params_str = "&".join(
                 f"{k}={v}" for k, v in requested_query_params.items()
             )

+        logging_url = str(url)
+        if requested_query_params_str:
             if "?" in str(url):
                 logging_url = str(url) + "&" + requested_query_params_str
             else:

@@ -409,6 +421,14 @@ async def pass_through_request(
             status_code=response.status_code,
         )

+        verbose_proxy_logger.debug("request method: {}".format(request.method))
+        verbose_proxy_logger.debug("request url: {}".format(url))
+        verbose_proxy_logger.debug("request headers: {}".format(headers))
+        verbose_proxy_logger.debug(
+            "requested_query_params={}".format(requested_query_params)
+        )
+        verbose_proxy_logger.debug("request body: {}".format(_parsed_body))
+
         response = await async_client.request(
             method=request.method,
             url=url,
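A small self-contained sketch of the body-parsing fallback used above (the function name is just for illustration): empty bodies map to None, Python-literal bodies are accepted, and anything else must be valid JSON.

import ast
import json
from typing import Any, Optional


def parse_pass_through_body(raw: Optional[bytes]) -> Optional[Any]:
    # Empty request bodies are treated as "no body" rather than a parse error.
    if raw is None or raw == b"":
        return None
    body_str = raw.decode()
    try:
        # Accept python-literal payloads, e.g. b"{'key': 'value'}"
        return ast.literal_eval(body_str)
    except Exception:
        # Fall back to strict JSON
        return json.loads(body_str)


print(parse_pass_through_body(b""))                     # None
print(parse_pass_through_body(b"{'a': 1}"))             # {'a': 1}
print(parse_pass_through_body(b'{"a": 1, "b": null}'))  # {'a': 1, 'b': None}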
@@ -1,20 +1,18 @@
 model_list:
-  - model_name: fake-openai-endpoint
+  - model_name: gpt-4
     litellm_params:
       model: openai/fake
       api_key: fake-key
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
-  - model_name: openai-embedding
-    litellm_params:
-      model: openai/text-embedding-3-small
-      api_key: os.environ/OPENAI_API_KEY

-litellm_settings:
-  set_verbose: True
-  cache: True # set cache responses to True, litellm defaults to using a redis cache
-  cache_params:
-    type: qdrant-semantic
-    qdrant_semantic_cache_embedding_model: openai-embedding
-    qdrant_collection_name: test_collection
-    qdrant_quantization_config: binary
-    similarity_threshold: 0.8 # similarity threshold for semantic cache
+guardrails:
+  - guardrail_name: "lakera-pre-guard"
+    litellm_params:
+      guardrail: lakera # supported values: "aporia", "bedrock", "lakera"
+      mode: "during_call"
+      api_key: os.environ/LAKERA_API_KEY
+      api_base: os.environ/LAKERA_API_BASE
+      category_thresholds:
+        prompt_injection: 0.1
+        jailbreak: 0.1
@@ -1498,6 +1498,11 @@ class ProxyConfig:
                     litellm.get_secret(secret_name=key, default_value=value)
                 )

+            # check if litellm_license in general_settings
+            if "LITELLM_LICENSE" in environment_variables:
+                _license_check.license_str = os.getenv("LITELLM_LICENSE", None)
+                premium_user = _license_check.is_premium()
+
         ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
         litellm_settings = config.get("litellm_settings", None)
         if litellm_settings is None:

@@ -1878,6 +1883,11 @@ class ProxyConfig:
                     + CommonProxyErrors.not_premium_user.value
                 )

+        # check if litellm_license in general_settings
+        if "litellm_license" in general_settings:
+            _license_check.license_str = general_settings["litellm_license"]
+            premium_user = _license_check.is_premium()
+
         router_params: dict = {
             "cache_responses": litellm.cache
             != None,  # cache if user passed in cache values

@@ -2784,11 +2794,14 @@ async def startup_event():
             await custom_db_client.connect()

     if prisma_client is not None and master_key is not None:
-        # add master key to db
         if os.getenv("PROXY_ADMIN_ID", None) is not None:
             litellm_proxy_admin_name = os.getenv(
                 "PROXY_ADMIN_ID", litellm_proxy_admin_name
             )
+        if general_settings.get("disable_adding_master_key_hash_to_db") is True:
+            verbose_proxy_logger.info("Skipping writing master key hash to db")
+        else:
+            # add master key to db
             asyncio.create_task(
                 generate_key_helper_fn(
                     request_type="user",

@@ -3011,6 +3024,29 @@ async def chat_completion(
     model: Optional[str] = None,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
+    """
+
+    Follows the exact same API spec as `OpenAI's Chat API https://platform.openai.com/docs/api-reference/chat`
+
+    ```bash
+    curl -X POST http://localhost:4000/v1/chat/completions \
+
+    -H "Content-Type: application/json" \
+
+    -H "Authorization: Bearer sk-1234" \
+
+    -d '{
+        "model": "gpt-4o",
+        "messages": [
+            {
+                "role": "user",
+                "content": "Hello!"
+            }
+        ]
+    }'
+    ```
+
+    """
     global general_settings, user_debug, proxy_logging_obj, llm_model_list

     data = {}

@@ -3268,6 +3304,24 @@ async def completion(
     model: Optional[str] = None,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
+    """
+    Follows the exact same API spec as `OpenAI's Completions API https://platform.openai.com/docs/api-reference/completions`
+
+    ```bash
+    curl -X POST http://localhost:4000/v1/completions \
+
+    -H "Content-Type: application/json" \
+
+    -H "Authorization: Bearer sk-1234" \
+
+    -d '{
+        "model": "gpt-3.5-turbo-instruct",
+        "prompt": "Once upon a time",
+        "max_tokens": 50,
+        "temperature": 0.7
+    }'
+    ```
+    """
     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     data = {}
     try:

@@ -3474,6 +3528,23 @@ async def embeddings(
     model: Optional[str] = None,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
+    """
+    Follows the exact same API spec as `OpenAI's Embeddings API https://platform.openai.com/docs/api-reference/embeddings`
+
+    ```bash
+    curl -X POST http://localhost:4000/v1/embeddings \
+
+    -H "Content-Type: application/json" \
+
+    -H "Authorization: Bearer sk-1234" \
+
+    -d '{
+        "model": "text-embedding-ada-002",
+        "input": "The quick brown fox jumps over the lazy dog"
+    }'
+    ```
+
+    """
     global proxy_logging_obj
     data: Any = {}
     try:

@@ -3481,6 +3552,11 @@ async def embeddings(
         body = await request.body()
         data = orjson.loads(body)

+        verbose_proxy_logger.debug(
+            "Request received by LiteLLM:\n%s",
+            json.dumps(data, indent=4),
+        )
+
         # Include original request and headers in the data
         data = await add_litellm_data_to_request(
             data=data,
@@ -1,4 +1,6 @@
 import json
+import os
+import secrets
 import traceback
 from typing import Optional


@@ -8,12 +10,30 @@ from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
 from litellm.proxy.utils import hash_token


+def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
+    if _master_key is None:
+        return False
+
+    ## string comparison
+    is_master_key = secrets.compare_digest(api_key, _master_key)
+    if is_master_key:
+        return True
+
+    ## hash comparison
+    is_master_key = secrets.compare_digest(api_key, hash_token(_master_key))
+    if is_master_key:
+        return True
+
+    return False
+
+
 def get_logging_payload(
     kwargs, response_obj, start_time, end_time, end_user_id: Optional[str]
 ) -> SpendLogsPayload:
     from pydantic import Json

     from litellm.proxy._types import LiteLLM_SpendLogs
+    from litellm.proxy.proxy_server import general_settings, master_key

     verbose_proxy_logger.debug(
         f"SpendTable: get_logging_payload - kwargs: {kwargs}\n\n"

@@ -36,9 +56,15 @@ def get_logging_payload(
     usage = dict(usage)
     id = response_obj.get("id", kwargs.get("litellm_call_id"))
     api_key = metadata.get("user_api_key", "")
-    if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"):
-        # hash the api_key
-        api_key = hash_token(api_key)
+    if api_key is not None and isinstance(api_key, str):
+        if api_key.startswith("sk-"):
+            # hash the api_key
+            api_key = hash_token(api_key)
+        if (
+            _is_master_key(api_key=api_key, _master_key=master_key)
+            and general_settings.get("disable_adding_master_key_hash_to_db") is True
+        ):
+            api_key = "litellm_proxy_master_key"  # use a known alias, if the user disabled storing master key in db

     _model_id = metadata.get("model_info", {}).get("id", "")
     _model_group = metadata.get("model_group", "")
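For reference, a standalone sketch of the comparison pattern `_is_master_key` relies on; the sha256 stand-in for litellm's `hash_token` helper is an assumption made for illustration.

import hashlib
import secrets
from typing import Optional


def sha256_hash(token: str) -> str:
    # stand-in for litellm.proxy.utils.hash_token (assumed to be a sha256 hex digest)
    return hashlib.sha256(token.encode()).hexdigest()


def is_master_key(api_key: str, master_key: Optional[str]) -> bool:
    if master_key is None:
        return False
    # constant-time comparison against the raw key and against its hash
    return secrets.compare_digest(api_key, master_key) or secrets.compare_digest(
        api_key, sha256_hash(master_key)
    )


print(is_master_key("sk-1234", "sk-1234"))               # True
print(is_master_key(sha256_hash("sk-1234"), "sk-1234"))  # True
print(is_master_key("sk-other", "sk-1234"))              # False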
litellm/proxy/tests/test_vtx_embedding.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+import openai
+
+client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
+
+# # request sent to model set on litellm proxy, `litellm --model`
+response = client.embeddings.create(
+    model="multimodalembedding@001",
+    input=[],
+    extra_body={
+        "instances": [
+            {
+                "image": {
+                    "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
+                },
+                "text": "this is a unicorn",
+            },
+        ],
+    },
+)
+
+print(response)

litellm/proxy/tests/test_vtx_sdk_embedding.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+import vertexai
+from google.auth.credentials import Credentials
+from vertexai.vision_models import (
+    Image,
+    MultiModalEmbeddingModel,
+    Video,
+    VideoSegmentConfig,
+)
+
+LITELLM_PROXY_API_KEY = "sk-1234"
+LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
+
+import datetime
+
+
+class CredentialsWrapper(Credentials):
+    def __init__(self, token=None):
+        super().__init__()
+        self.token = token
+        self.expiry = None  # or set to a future date if needed
+
+    def refresh(self, request):
+        pass
+
+    def apply(self, headers, token=None):
+        headers["Authorization"] = f"Bearer {self.token}"
+
+    @property
+    def expired(self):
+        return False  # Always consider the token as non-expired
+
+    @property
+    def valid(self):
+        return True  # Always consider the credentials as valid
+
+
+credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
+
+vertexai.init(
+    project="adroit-crow-413218",
+    location="us-central1",
+    api_endpoint=LITELLM_PROXY_BASE,
+    credentials=credentials,
+    api_transport="rest",
+)
+
+model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding")
+image = Image.load_from_file(
+    "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
+)
+
+embeddings = model.get_embeddings(
+    image=image,
+    contextual_text="Colosseum",
+    dimension=1408,
+)
+print(f"Image Embedding: {embeddings.image_embedding}")
+print(f"Text Embedding: {embeddings.text_embedding}")
@@ -44,6 +44,7 @@ from litellm.proxy._types import (
     DynamoDBArgs,
     LiteLLM_VerificationTokenView,
     LitellmUserRoles,
+    Member,
     ResetTeamBudgetRequest,
     SpendLogsMetadata,
     SpendLogsPayload,

@@ -1395,6 +1396,7 @@ class PrismaClient:
                         t.blocked AS team_blocked,
                         t.team_alias AS team_alias,
                         t.metadata AS team_metadata,
+                        t.members_with_roles AS team_members_with_roles,
                         tm.spend AS team_member_spend,
                         m.aliases as team_model_aliases
                     FROM "LiteLLM_VerificationToken" AS v

@@ -1412,6 +1414,33 @@ class PrismaClient:
                         response["team_models"] = []
                     if response["team_blocked"] is None:
                         response["team_blocked"] = False
+
+                    team_member: Optional[Member] = None
+                    if (
+                        response["team_members_with_roles"] is not None
+                        and response["user_id"] is not None
+                    ):
+                        ## find the team member corresponding to user id
+                        """
+                        [
+                            {
+                                "role": "admin",
+                                "user_id": "default_user_id",
+                                "user_email": null
+                            },
+                            {
+                                "role": "user",
+                                "user_id": null,
+                                "user_email": "test@email.com"
+                            }
+                        ]
+                        """
+                        for tm in response["team_members_with_roles"]:
+                            if tm.get("user_id") is not None and response[
+                                "user_id"
+                            ] == tm.get("user_id"):
+                                team_member = Member(**tm)
+                    response["team_member"] = team_member
                     response = LiteLLM_VerificationTokenView(
                         **response, last_refreshed_at=time.time()
                     )
@@ -25,6 +25,9 @@ from litellm.batches.main import FileObject
 from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
 from litellm.proxy._types import *
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
+    create_pass_through_route,
+)

 router = APIRouter()
 default_vertex_config = None
@@ -70,10 +73,17 @@ def exception_handler(e: Exception):
     )


-async def execute_post_vertex_ai_request(
+@router.api_route(
+    "/vertex-ai/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]
+)
+async def vertex_proxy_route(
+    endpoint: str,
     request: Request,
-    route: str,
+    fastapi_response: Response,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
+    encoded_endpoint = httpx.URL(endpoint).path
+
     from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance

     if default_vertex_config is None:
@@ -83,250 +93,52 @@ async def execute_post_vertex_ai_request(
     vertex_project = default_vertex_config.get("vertex_project", None)
     vertex_location = default_vertex_config.get("vertex_location", None)
     vertex_credentials = default_vertex_config.get("vertex_credentials", None)
+    base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"

-    request_data_json = {}
-    body = await request.body()
-    body_str = body.decode()
-    if len(body_str) > 0:
-        try:
-            request_data_json = ast.literal_eval(body_str)
-        except:
-            request_data_json = json.loads(body_str)
-
-    verbose_proxy_logger.debug(
-        "Request received by LiteLLM:\n{}".format(
-            json.dumps(request_data_json, indent=4)
-        ),
-    )
-
-    response = (
-        await vertex_fine_tuning_apis_instance.pass_through_vertex_ai_POST_request(
-            request_data=request_data_json,
+    auth_header, _ = vertex_fine_tuning_apis_instance._get_token_and_url(
+        model="",
+        gemini_api_key=None,
+        vertex_credentials=vertex_credentials,
         vertex_project=vertex_project,
         vertex_location=vertex_location,
-            vertex_credentials=vertex_credentials,
-            request_route=route,
-        )
+        stream=False,
+        custom_llm_provider="vertex_ai_beta",
+        api_base="",
     )

-    return response
+    headers = {
+        "Authorization": f"Bearer {auth_header}",
+    }
+
+    request_route = encoded_endpoint
+    verbose_proxy_logger.debug("request_route %s", request_route)
+
+    # Ensure endpoint starts with '/' for proper URL construction
+    if not encoded_endpoint.startswith("/"):
+        encoded_endpoint = "/" + encoded_endpoint
+
+    # Construct the full target URL using httpx
+    base_url = httpx.URL(base_target_url)
+    updated_url = base_url.copy_with(path=encoded_endpoint)
+
+    verbose_proxy_logger.debug("updated url %s", updated_url)
+
+    ## check for streaming
+    is_streaming_request = False
+    if "stream" in str(updated_url):
+        is_streaming_request = True
+
+    ## CREATE PASS-THROUGH
+    endpoint_func = create_pass_through_route(
+        endpoint=endpoint,
+        target=str(updated_url),
+        custom_headers=headers,
+    )  # dynamically construct pass-through endpoint based on incoming path
+    received_value = await endpoint_func(
+        request,
+        fastapi_response,
+        user_api_key_dict,
+        stream=is_streaming_request,
+    )
+
+    return received_value


-@router.post(
-    "/vertex-ai/publishers/google/models/{model_id:path}:generateContent",
-    dependencies=[Depends(user_api_key_auth)],
-    tags=["Vertex AI endpoints"],
-)
-async def vertex_generate_content(
-    request: Request,
-    fastapi_response: Response,
-    model_id: str,
-    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-):
-    """
-    this is a pass through endpoint for the Vertex AI API. /generateContent endpoint
-
-    Example Curl:
-    ```
-    curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent \
-    -H "Content-Type: application/json" \
-    -H "Authorization: Bearer sk-1234" \
-    -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
-    ```
-
-    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#rest
-    it uses the vertex ai credentials on the proxy and forwards to vertex ai api
-    """
-    try:
-        response = await execute_post_vertex_ai_request(
-            request=request,
-            route=f"/publishers/google/models/{model_id}:generateContent",
-        )
-        return response
-    except Exception as e:
-        raise exception_handler(e) from e
-
-
-@router.post(
-    "/vertex-ai/publishers/google/models/{model_id:path}:predict",
-    dependencies=[Depends(user_api_key_auth)],
-    tags=["Vertex AI endpoints"],
-)
-async def vertex_predict_endpoint(
-    request: Request,
-    fastapi_response: Response,
-    model_id: str,
-    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-):
-    """
-    this is a pass through endpoint for the Vertex AI API. /predict endpoint
-    Use this for:
-    - Embeddings API - Text Embedding, Multi Modal Embedding
-    - Imagen API
-    - Code Completion API
-
-    Example Curl:
-    ```
-    curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:predict \
-    -H "Content-Type: application/json" \
-    -H "Authorization: Bearer sk-1234" \
-    -d '{"instances":[{"content": "gm"}]}'
-    ```
-
-    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#generative-ai-get-text-embedding-drest
-    it uses the vertex ai credentials on the proxy and forwards to vertex ai api
-    """
-    try:
-        response = await execute_post_vertex_ai_request(
-            request=request,
-            route=f"/publishers/google/models/{model_id}:predict",
-        )
-        return response
-    except Exception as e:
-        raise exception_handler(e) from e
-
-
-@router.post(
-    "/vertex-ai/publishers/google/models/{model_id:path}:countTokens",
-    dependencies=[Depends(user_api_key_auth)],
-    tags=["Vertex AI endpoints"],
-)
-async def vertex_countTokens_endpoint(
-    request: Request,
-    fastapi_response: Response,
-    model_id: str,
-    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-):
-    """
-    this is a pass through endpoint for the Vertex AI API. /countTokens endpoint
-    https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/count-tokens#curl
-
-    Example Curl:
-    ```
-    curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
-    -H "Content-Type: application/json" \
-    -H "Authorization: Bearer sk-1234" \
-    -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
-    ```
-
-    it uses the vertex ai credentials on the proxy and forwards to vertex ai api
-    """
-    try:
-        response = await execute_post_vertex_ai_request(
-            request=request,
-            route=f"/publishers/google/models/{model_id}:countTokens",
-        )
-        return response
-    except Exception as e:
-        raise exception_handler(e) from e
-
-
-@router.post(
-    "/vertex-ai/batchPredictionJobs",
-    dependencies=[Depends(user_api_key_auth)],
-    tags=["Vertex AI endpoints"],
-)
-async def vertex_create_batch_prediction_job(
-    request: Request,
-    fastapi_response: Response,
-    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-):
-    """
-    this is a pass through endpoint for the Vertex AI API. /batchPredictionJobs endpoint
-
-    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/batch-prediction-api#syntax
-
-    it uses the vertex ai credentials on the proxy and forwards to vertex ai api
-    """
-    try:
-        response = await execute_post_vertex_ai_request(
-            request=request,
-            route="/batchPredictionJobs",
-        )
-        return response
-    except Exception as e:
-        raise exception_handler(e) from e
-
-
-@router.post(
-    "/vertex-ai/tuningJobs",
-    dependencies=[Depends(user_api_key_auth)],
-    tags=["Vertex AI endpoints"],
-)
-async def vertex_create_fine_tuning_job(
-    request: Request,
-    fastapi_response: Response,
-    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-):
-    """
-    this is a pass through endpoint for the Vertex AI API. /tuningJobs endpoint
-
-    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning
-
-    it uses the vertex ai credentials on the proxy and forwards to vertex ai api
-    """
-    try:
-        response = await execute_post_vertex_ai_request(
-            request=request,
-            route="/tuningJobs",
-        )
-        return response
-    except Exception as e:
-        raise exception_handler(e) from e
-
-
-@router.post(
-    "/vertex-ai/tuningJobs/{job_id:path}:cancel",
-    dependencies=[Depends(user_api_key_auth)],
-    tags=["Vertex AI endpoints"],
-)
-async def vertex_cancel_fine_tuning_job(
-    request: Request,
-    job_id: str,
-    fastapi_response: Response,
-    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-):
-    """
-    this is a pass through endpoint for the Vertex AI API. tuningJobs/{job_id:path}:cancel
-
-    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#cancel_a_tuning_job
-
-    it uses the vertex ai credentials on the proxy and forwards to vertex ai api
-    """
-    try:
-
-        response = await execute_post_vertex_ai_request(
-            request=request,
-            route=f"/tuningJobs/{job_id}:cancel",
-        )
-        return response
-    except Exception as e:
-        raise exception_handler(e) from e
-
-
-@router.post(
-    "/vertex-ai/cachedContents",
-    dependencies=[Depends(user_api_key_auth)],
-    tags=["Vertex AI endpoints"],
-)
-async def vertex_create_add_cached_content(
-    request: Request,
-    fastapi_response: Response,
-    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-):
-    """
-    this is a pass through endpoint for the Vertex AI API. /cachedContents endpoint
-
-    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-create#create-context-cache-sample-drest
-
-    it uses the vertex ai credentials on the proxy and forwards to vertex ai api
-    """
-    try:
-        response = await execute_post_vertex_ai_request(
-            request=request,
-            route="/cachedContents",
-        )
-        return response
-    except Exception as e:
-        raise exception_handler(e) from e
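The per-endpoint routes above are replaced by a single catch-all `/vertex-ai/{endpoint}` pass-through. A hedged usage sketch, reusing the proxy URL and placeholder key from the removed docstrings (assumed, not prescribed by the diff):

```python
# Hedged sketch: calling the generic /vertex-ai/{endpoint} pass-through route.
# Assumes a LiteLLM proxy at http://localhost:4000 with virtual key "sk-1234",
# the same placeholders used in the removed per-endpoint docstrings.
import httpx

resp = httpx.post(
    "http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent",
    headers={
        "Content-Type": "application/json",
        "Authorization": "Bearer sk-1234",
    },
    json={"contents": [{"role": "user", "parts": [{"text": "hi"}]}]},
)
print(resp.json())
```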
@@ -15,7 +15,7 @@ import asyncio
 import json
 import os
 import tempfile
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch

 import pytest

@@ -501,6 +501,8 @@ async def test_async_vertexai_streaming_response():
         assert len(complete_response) > 0
     except litellm.RateLimitError as e:
         pass
+    except litellm.APIConnectionError:
+        pass
     except litellm.Timeout as e:
         pass
     except litellm.InternalServerError as e:
@@ -955,6 +957,8 @@ async def test_partner_models_httpx(model, sync_mode):
         assert isinstance(response._hidden_params["response_cost"], float)
     except litellm.RateLimitError as e:
         pass
+    except litellm.InternalServerError as e:
+        pass
     except Exception as e:
         if "429 Quota exceeded" in str(e):
             pass
@@ -1004,7 +1008,9 @@ async def test_partner_models_httpx_streaming(model, sync_mode):
             idx += 1

         print(f"response: {response}")
-    except litellm.RateLimitError:
+    except litellm.RateLimitError as e:
+        pass
+    except litellm.InternalServerError as e:
         pass
     except Exception as e:
         if "429 Quota exceeded" in str(e):
@@ -1558,6 +1564,16 @@ async def test_gemini_pro_json_schema_args_sent_httpx_openai_schema(
                 "response_schema"
                 in mock_call.call_args.kwargs["json"]["generationConfig"]
             )
+            assert (
+                "response_mime_type"
+                in mock_call.call_args.kwargs["json"]["generationConfig"]
+            )
+            assert (
+                mock_call.call_args.kwargs["json"]["generationConfig"][
+                    "response_mime_type"
+                ]
+                == "application/json"
+            )
         else:
             assert (
                 "response_schema"
@@ -1826,6 +1842,71 @@ def test_vertexai_embedding():
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.asyncio
+async def test_vertexai_multimodal_embedding():
+    load_vertex_ai_credentials()
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "predictions": [
+                {
+                    "imageEmbedding": [0.1, 0.2, 0.3],  # Simplified example
+                    "textEmbedding": [0.4, 0.5, 0.6],  # Simplified example
+                }
+            ]
+        }
+
+    mock_response.json = return_val
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "instances": [
+            {
+                "image": {
+                    "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
+                },
+                "text": "this is a unicorn",
+            }
+        ]
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        # Act: Call the litellm.aembedding function
+        response = await litellm.aembedding(
+            model="vertex_ai/multimodalembedding@001",
+            input=[
+                {
+                    "image": {
+                        "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
+                    },
+                    "text": "this is a unicorn",
+                },
+            ],
+        )
+
+        # Assert
+        mock_post.assert_called_once()
+        _, kwargs = mock_post.call_args
+        args_to_vertexai = kwargs["json"]
+
+        print("args to vertex ai call:", args_to_vertexai)
+
+        assert args_to_vertexai == expected_payload
+        assert response.model == "multimodalembedding@001"
+        assert len(response.data) == 1
+        response_data = response.data[0]
+        assert "imageEmbedding" in response_data
+        assert "textEmbedding" in response_data
+
+        # Optional: Print for debugging
+        print("Arguments passed to Vertex AI:", args_to_vertexai)
+        print("Response:", response)
+
+
 @pytest.mark.skip(
     reason="new test - works locally running into vertex version issues on ci/cd"
 )
@@ -738,8 +738,9 @@ def test_bedrock_system_prompt(system, model):
         "temperature": 0.3,
         "messages": [
             {"role": "system", "content": system},
-            {"role": "user", "content": "hey, how's it going?"},
+            {"role": "assistant", "content": "hey, how's it going?"},
         ],
+        "user_continue_message": {"role": "user", "content": "Be a good bot!"},
     }
     response: ModelResponse = completion(
         model="bedrock/{}".format(model),
@@ -3653,6 +3653,7 @@ def test_completion_cohere():
         response = completion(
             model="command-r",
             messages=messages,
+            extra_headers={"Helicone-Property-Locale": "ko"},
         )
         print(response)
     except Exception as e:
@@ -1252,3 +1252,48 @@ def test_standard_logging_payload(model, turn_off_message_logging):
         ]
         if turn_off_message_logging:
             assert "redacted-by-litellm" == slobject["messages"][0]["content"]
+
+
+@pytest.mark.skip(reason="Works locally. Flaky on ci/cd")
+def test_aaastandard_logging_payload_cache_hit():
+    from litellm.types.utils import StandardLoggingPayload
+
+    # sync completion
+
+    litellm.cache = Cache()
+
+    _ = litellm.completion(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        caching=True,
+    )
+
+    customHandler = CompletionCustomHandler()
+    litellm.callbacks = [customHandler]
+    litellm.success_callback = []
+
+    with patch.object(
+        customHandler, "log_success_event", new=MagicMock()
+    ) as mock_client:
+        _ = litellm.completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            caching=True,
+        )
+
+        time.sleep(2)
+        mock_client.assert_called_once()
+
+        assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
+        assert (
+            mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
+            is not None
+        )
+
+        standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
+            "kwargs"
+        ]["standard_logging_object"]
+
+        assert standard_logging_object["cache_hit"] is True
+        assert standard_logging_object["response_cost"] == 0
+        assert standard_logging_object["saved_cache_cost"] > 0
@@ -54,6 +54,7 @@ def get_current_weather(location, unit="fahrenheit"):
 )
 def test_parallel_function_call(model):
     try:
+        litellm.set_verbose = True
         # Step 1: send the conversation and available functions to the model
         messages = [
             {
@@ -141,6 +142,8 @@ def test_parallel_function_call(model):
             drop_params=True,
         )  # get a new response from the model where it can see the function response
         print("second response\n", second_response)
+    except litellm.RateLimitError:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

@@ -322,6 +325,7 @@ def test_groq_parallel_function_call():
                     location=function_args.get("location"),
                     unit=function_args.get("unit"),
                 )
+
                 messages.append(
                     {
                         "tool_call_id": tool_call.id,
@@ -337,27 +341,3 @@ def test_groq_parallel_function_call():
         print("second response\n", second_response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-
-
-@pytest.mark.parametrize("model", ["gemini/gemini-1.5-pro"])
-def test_simple_function_call_function_param(model):
-    try:
-        litellm.set_verbose = True
-        messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
-        response = completion(
-            model=model,
-            messages=messages,
-            tools=[
-                {
-                    "type": "function",
-                    "function": {
-                        "name": "plot",
-                        "description": "Generate plots",
-                    },
-                }
-            ],
-            tool_choice="auto",
-        )
-        print(f"response: {response}")
-    except Exception as e:
-        raise e
@@ -116,6 +116,8 @@ async def test_async_image_generation_openai():
         )
         print(f"response: {response}")
         assert len(response.data) > 0
+    except litellm.APIError:
+        pass
     except litellm.RateLimitError as e:
         pass
     except litellm.ContentPolicyViolationError:
@@ -2328,6 +2328,11 @@ async def test_master_key_hashing(prisma_client):
        from litellm.proxy.proxy_server import user_api_key_cache

        _team_id = "ishaans-special-team_{}".format(uuid.uuid4())
+       user_api_key_dict = UserAPIKeyAuth(
+           user_role=LitellmUserRoles.PROXY_ADMIN,
+           api_key="sk-1234",
+           user_id="1234",
+       )
        await new_team(
            NewTeamRequest(team_id=_team_id),
            user_api_key_dict=UserAPIKeyAuth(
@@ -2343,7 +2348,8 @@ async def test_master_key_hashing(prisma_client):
                models=["azure-gpt-3.5"],
                team_id=_team_id,
                tpm_limit=20,
-           )
+           ),
+           user_api_key_dict=user_api_key_dict,
        )
        print(_response)
        assert _response.models == ["azure-gpt-3.5"]
@@ -19,7 +19,11 @@ from litellm.types.completion import (
     ChatCompletionSystemMessageParam,
     ChatCompletionUserMessageParam,
 )
-from litellm.utils import get_optional_params, get_optional_params_embeddings
+from litellm.utils import (
+    get_optional_params,
+    get_optional_params_embeddings,
+    get_optional_params_image_gen,
+)

 ## get_optional_params_embeddings
 ### Models: OpenAI, Azure, Bedrock
@@ -430,7 +434,6 @@ def test_get_optional_params_image_gen():
     print(response)

     assert "aws_region_name" not in response
-
     response = litellm.utils.get_optional_params_image_gen(
         aws_region_name="us-east-1", custom_llm_provider="bedrock"
     )
@@ -463,3 +466,36 @@ def test_get_optional_params_num_retries():

         print(f"mock_client.call_args: {mock_client.call_args}")
         assert mock_client.call_args.kwargs["max_retries"] == 10
+
+
+@pytest.mark.parametrize(
+    "provider",
+    [
+        "vertex_ai",
+        "vertex_ai_beta",
+    ],
+)
+def test_vertex_safety_settings(provider):
+    litellm.vertex_ai_safety_settings = [
+        {
+            "category": "HARM_CATEGORY_HARASSMENT",
+            "threshold": "BLOCK_NONE",
+        },
+        {
+            "category": "HARM_CATEGORY_HATE_SPEECH",
+            "threshold": "BLOCK_NONE",
+        },
+        {
+            "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+            "threshold": "BLOCK_NONE",
+        },
+        {
+            "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+            "threshold": "BLOCK_NONE",
+        },
+    ]
+
+    optional_params = get_optional_params(
+        model="gemini-1.5-pro", custom_llm_provider=provider
+    )
+    assert len(optional_params) == 1
@@ -909,7 +909,7 @@ async def test_create_team_member_add(prisma_client, new_member_method):

         await team_member_add(
             data=team_member_add_request,
-            user_api_key_dict=UserAPIKeyAuth(),
+            user_api_key_dict=UserAPIKeyAuth(user_role="proxy_admin"),
             http_request=Request(
                 scope={"type": "http", "path": "/user/new"},
             ),
@@ -930,6 +930,172 @@ async def test_create_team_member_add(prisma_client, new_member_method):
     )


+@pytest.mark.parametrize("team_member_role", ["admin", "user"])
+@pytest.mark.parametrize("team_route", ["/team/member_add", "/team/member_delete"])
+@pytest.mark.asyncio
+async def test_create_team_member_add_team_admin_user_api_key_auth(
+    prisma_client, team_member_role, team_route
+):
+    import time
+
+    from fastapi import Request
+
+    from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
+    from litellm.proxy.proxy_server import (
+        ProxyException,
+        hash_token,
+        user_api_key_auth,
+        user_api_key_cache,
+    )
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(litellm, "max_internal_user_budget", 10)
+    setattr(litellm, "internal_user_budget_duration", "5m")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+    user = f"ishaan {uuid.uuid4().hex}"
+    _team_id = "litellm-test-client-id-new"
+    user_key = "sk-12345678"
+
+    valid_token = UserAPIKeyAuth(
+        team_id=_team_id,
+        token=hash_token(user_key),
+        team_member=Member(role=team_member_role, user_id=user),
+        last_refreshed_at=time.time(),
+    )
+    user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
+
+    team_obj = LiteLLM_TeamTableCachedObj(
+        team_id=_team_id,
+        blocked=False,
+        last_refreshed_at=time.time(),
+        metadata={"guardrails": {"modify_guardrails": False}},
+    )
+
+    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
+
+    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
+
+    ## TEST IF TEAM ADMIN ALLOWED TO CALL /MEMBER_ADD ENDPOINT
+    import json
+
+    from starlette.datastructures import URL
+
+    request = Request(scope={"type": "http"})
+    request._url = URL(url=team_route)
+
+    body = {}
+    json_bytes = json.dumps(body).encode("utf-8")
+
+    request._body = json_bytes
+
+    ## ALLOWED BY USER_API_KEY_AUTH
+    await user_api_key_auth(request=request, api_key="Bearer " + user_key)
+
+
+@pytest.mark.parametrize("new_member_method", ["user_id", "user_email"])
+@pytest.mark.parametrize("user_role", ["admin", "user"])
+@pytest.mark.asyncio
+async def test_create_team_member_add_team_admin(
+    prisma_client, new_member_method, user_role
+):
+    """
+    Relevant issue - https://github.com/BerriAI/litellm/issues/5300
+
+    Allow team admins to:
+    - Add and remove team members
+    - raise error if team member not an existing 'internal_user'
+    """
+    import time
+
+    from fastapi import Request
+
+    from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
+    from litellm.proxy.proxy_server import (
+        HTTPException,
+        ProxyException,
+        hash_token,
+        user_api_key_auth,
+        user_api_key_cache,
+    )
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(litellm, "max_internal_user_budget", 10)
+    setattr(litellm, "internal_user_budget_duration", "5m")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+    user = f"ishaan {uuid.uuid4().hex}"
+    _team_id = "litellm-test-client-id-new"
+    user_key = "sk-12345678"
+
+    valid_token = UserAPIKeyAuth(
+        team_id=_team_id,
+        user_id=user,
+        token=hash_token(user_key),
+        last_refreshed_at=time.time(),
+    )
+    user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
+
+    team_obj = LiteLLM_TeamTableCachedObj(
+        team_id=_team_id,
+        blocked=False,
+        last_refreshed_at=time.time(),
+        members_with_roles=[Member(role=user_role, user_id=user)],
+        metadata={"guardrails": {"modify_guardrails": False}},
+    )
+
+    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
+
+    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
+    if new_member_method == "user_id":
+        data = {
+            "team_id": _team_id,
+            "member": [{"role": "user", "user_id": user}],
+        }
+    elif new_member_method == "user_email":
+        data = {
+            "team_id": _team_id,
+            "member": [{"role": "user", "user_email": user}],
+        }
+    team_member_add_request = TeamMemberAddRequest(**data)
+
+    with patch(
+        "litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
+        new_callable=AsyncMock,
+    ) as mock_litellm_usertable:
+        mock_client = AsyncMock()
+        mock_litellm_usertable.upsert = mock_client
+        mock_litellm_usertable.find_many = AsyncMock(return_value=None)
+
+        try:
+            await team_member_add(
+                data=team_member_add_request,
+                user_api_key_dict=valid_token,
+                http_request=Request(
+                    scope={"type": "http", "path": "/user/new"},
+                ),
+            )
+        except HTTPException as e:
+            if user_role == "user":
+                assert e.status_code == 403
+            else:
+                raise e
+
+        mock_client.assert_called()
+
+        print(f"mock_client.call_args: {mock_client.call_args}")
+        print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs))
+
+        assert (
+            mock_client.call_args.kwargs["data"]["create"]["max_budget"]
+            == litellm.max_internal_user_budget
+        )
+        assert (
+            mock_client.call_args.kwargs["data"]["create"]["budget_duration"]
+            == litellm.internal_user_budget_duration
+        )
+
+
 @pytest.mark.asyncio
 async def test_user_info_team_list(prisma_client):
     """Assert user_info for admin calls team_list function"""
litellm/types/llms/ollama.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+import json
+from typing import Any, Optional, TypedDict, Union
+
+from pydantic import BaseModel
+from typing_extensions import (
+    Protocol,
+    Required,
+    Self,
+    TypeGuard,
+    get_origin,
+    override,
+    runtime_checkable,
+)
+
+
+class OllamaToolCallFunction(
+    TypedDict
+):  # follows - https://github.com/ollama/ollama/blob/6bd8a4b0a1ac15d5718f52bbe1cd56f827beb694/api/types.go#L148
+    name: str
+    arguments: dict
+
+
+class OllamaToolCall(TypedDict):
+    function: OllamaToolCallFunction
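For illustration, a sketch of how these TypedDicts might be populated by hand (the function name and arguments are placeholders, not values from this diff):

```python
# Illustrative sketch: constructing a tool call in the shape defined above.
from litellm.types.llms.ollama import OllamaToolCall

tool_call: OllamaToolCall = {
    "function": {
        "name": "get_current_weather",
        "arguments": {"location": "Boston", "unit": "fahrenheit"},
    }
}
print(tool_call["function"]["name"])  # -> get_current_weather
```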
@@ -1,6 +1,6 @@
 import json
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union

 from typing_extensions import (
     Protocol,
@@ -305,3 +305,18 @@ class ResponseTuningJob(TypedDict):
     ]
     createTime: Optional[str]
     updateTime: Optional[str]
+
+
+class InstanceVideo(TypedDict, total=False):
+    gcsUri: str
+    videoSegmentConfig: Tuple[float, float, float]
+
+
+class Instance(TypedDict, total=False):
+    text: str
+    image: Dict[str, str]
+    video: InstanceVideo
+
+
+class VertexMultimodalEmbeddingRequest(TypedDict, total=False):
+    instances: List[Instance]
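As a rough illustration, a payload matching these new TypedDicts could be built like the sketch below, mirroring the expected_payload asserted in test_vertexai_multimodal_embedding earlier in this diff. The module path in the import is an assumption for the example, not stated by the diff.

```python
# Illustrative sketch: a VertexMultimodalEmbeddingRequest-shaped payload.
from litellm.types.llms.vertex_ai import (  # assumed module path for these types
    Instance,
    VertexMultimodalEmbeddingRequest,
)

instance: Instance = {
    "image": {"gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"},
    "text": "this is a unicorn",
}
request: VertexMultimodalEmbeddingRequest = {"instances": [instance]}
```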
@@ -1116,6 +1116,7 @@ all_litellm_params = [
     "cooldown_time",
     "cache_key",
     "max_retries",
+    "user_continue_message",
 ]

@@ -1218,6 +1219,7 @@ class StandardLoggingPayload(TypedDict):
     metadata: StandardLoggingMetadata
     cache_hit: Optional[bool]
     cache_key: Optional[str]
+    saved_cache_cost: Optional[float]
     request_tags: list
     end_user: Optional[str]
     requester_ip_address: Optional[str]
@@ -541,7 +541,7 @@ def function_setup(
             call_type == CallTypes.embedding.value
             or call_type == CallTypes.aembedding.value
         ):
-            messages = args[1] if len(args) > 1 else kwargs["input"]
+            messages = args[1] if len(args) > 1 else kwargs.get("input", None)
         elif (
             call_type == CallTypes.image_generation.value
             or call_type == CallTypes.aimage_generation.value
@@ -2323,6 +2323,7 @@ def get_litellm_params(
     output_cost_per_second=None,
     cooldown_time=None,
     text_completion=None,
+    user_continue_message=None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -2347,6 +2348,7 @@ def get_litellm_params(
         "output_cost_per_second": output_cost_per_second,
         "cooldown_time": cooldown_time,
         "text_completion": text_completion,
+        "user_continue_message": user_continue_message,
     }

     return litellm_params
@@ -3145,7 +3147,6 @@ def get_optional_params(
         or model in litellm.vertex_embedding_models
         or model in litellm.vertex_vision_models
     ):
-        print_verbose(f"(start) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK")
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
@@ -3157,9 +3158,8 @@ def get_optional_params(
             optional_params=optional_params,
         )

-        print_verbose(
-            f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}"
-        )
+        if litellm.vertex_ai_safety_settings is not None:
+            optional_params["safety_settings"] = litellm.vertex_ai_safety_settings
     elif custom_llm_provider == "gemini":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
@@ -3170,7 +3170,7 @@ def get_optional_params(
             optional_params=optional_params,
             model=model,
         )
-    elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
+    elif custom_llm_provider == "vertex_ai_beta":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
@@ -3185,6 +3185,8 @@ def get_optional_params(
                 else False
             ),
         )
+        if litellm.vertex_ai_safety_settings is not None:
+            optional_params["safety_settings"] = litellm.vertex_ai_safety_settings
     elif (
         custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models
     ):
@@ -4219,6 +4221,7 @@ def get_supported_openai_params(
             "presence_penalty",
             "stop",
             "n",
+            "extra_headers",
         ]
     elif custom_llm_provider == "cohere_chat":
         return [
@@ -4233,6 +4236,7 @@ def get_supported_openai_params(
             "tools",
             "tool_choice",
             "seed",
+            "extra_headers",
         ]
     elif custom_llm_provider == "maritalk":
         return [
@@ -7121,6 +7125,14 @@ def exception_type(
                         llm_provider="bedrock",
                         response=original_exception.response,
                     )
+                elif "A conversation must start with a user message." in error_str:
+                    exception_mapping_worked = True
+                    raise BadRequestError(
+                        message=f"BedrockException - {error_str}\n. Pass in default user message via `completion(..,user_continue_message=)` or enable `litellm.modify_params=True`.\nFor Proxy: do via `litellm_settings::modify_params: True` or user_continue_message under `litellm_params`",
+                        model=model,
+                        llm_provider="bedrock",
+                        response=original_exception.response,
+                    )
                 elif (
                     "Unable to locate credentials" in error_str
                     or "The security token included in the request is invalid"
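The new Bedrock error mapping points callers at `user_continue_message` or `litellm.modify_params`. A hedged sketch of the two workarounds it names (the model id and messages are illustrative placeholders, not taken from the diff):

```python
# Hedged sketch of the two workarounds named in the new BedrockException message.
import litellm

# Option 1: pass a default user message so the conversation starts with a user turn
response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",  # placeholder model id
    messages=[{"role": "assistant", "content": "hey, how's it going?"}],
    user_continue_message={"role": "user", "content": "Be a good bot!"},
)

# Option 2: let litellm adjust the prompt automatically
litellm.modify_params = True
```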
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.44.1"
+version = "1.44.2"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"

 [tool.commitizen]
-version = "1.44.1"
+version = "1.44.2"
 version_files = [
     "pyproject.toml:^version"
 ]