forked from phoenix/litellm-mirror
Merge branch 'main' into litellm_redis_cluster
commit 68cb5cae58
56 changed files with 2079 additions and 411 deletions
@ -282,7 +282,7 @@ jobs:
|
|||
pip install "pytest==7.3.1"
|
||||
pip install "pytest-asyncio==0.21.1"
|
||||
pip install aiohttp
|
||||
pip install openai
|
||||
pip install "openai==1.40.0"
|
||||
python -m pip install --upgrade pip
|
||||
pip install "pydantic==2.7.1"
|
||||
pip install "pytest==7.3.1"
|
||||
|
|
|
@ -13,8 +13,9 @@ spec:
|
|||
{{- include "litellm.selectorLabels" . | nindent 6 }}
|
||||
template:
|
||||
metadata:
|
||||
{{- with .Values.podAnnotations }}
|
||||
annotations:
|
||||
checksum/config: {{ include (print $.Template.BasePath "/configmap-litellm.yaml") . | sha256sum }}
|
||||
{{- with .Values.podAnnotations }}
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
labels:
|
||||
|
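The `checksum/config` annotation added above is a common Helm pattern: its rendered value is the SHA-256 of the rendered ConfigMap template, so any config change alters the pod template and triggers a rollout of the Deployment. A sketch of what it renders to (hash value is illustrative):

```yaml
# rendered pod template metadata (hash is illustrative)
metadata:
  annotations:
    checksum/config: "3f8c2a1d0b9e47c6a5d4e2f1908b7c6d5e4f3a2b1c0d9e8f7a6b5c4d3e2f1a0b"
```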
|
|
@ -81,6 +81,7 @@ Works for:
|
|||
```python
|
||||
import os
|
||||
from litellm import completion
|
||||
from pydantic import BaseModel
|
||||
|
||||
# add to env var
|
||||
os.environ["OPENAI_API_KEY"] = ""
|
||||
|
|
|
@ -8,6 +8,7 @@ liteLLM supports:
|
|||
|
||||
- [Custom Callback Functions](https://docs.litellm.ai/docs/observability/custom_callback)
|
||||
- [Langfuse](https://langfuse.com/docs)
|
||||
- [LangSmith](https://www.langchain.com/langsmith)
|
||||
- [Helicone](https://docs.helicone.ai/introduction)
|
||||
- [Traceloop](https://traceloop.com/docs)
|
||||
- [Lunary](https://lunary.ai/docs)
|
||||
|
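These integrations are all enabled the same way, by naming them in LiteLLM's callback lists. A minimal sketch (assumes the relevant provider API keys are already set in the environment):

```python
import litellm

litellm.success_callback = ["langfuse"]   # log successful calls
litellm.failure_callback = ["langfuse"]   # log failed calls

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)
```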
|
|
@ -56,7 +56,7 @@ response = litellm.completion(
|
|||
```
|
||||
|
||||
## Advanced
|
||||
### Set Langsmith fields - Custom Project, Run names, tags
|
||||
### Set Langsmith fields
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
@ -77,7 +77,15 @@ response = litellm.completion(
|
|||
metadata={
|
||||
"run_name": "litellmRUN", # langsmith run name
|
||||
"project_name": "litellm-completion", # langsmith project name
|
||||
"tags": ["model1", "prod-2"] # tags to log on langsmith
|
||||
"run_id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", # langsmith run id
|
||||
"parent_run_id": "f8faf8c1-9778-49a4-9004-628cdb0047e5", # langsmith run parent run id
|
||||
"trace_id": "df570c03-5a03-4cea-8df0-c162d05127ac", # langsmith run trace id
|
||||
"session_id": "1ffd059c-17ea-40a8-8aef-70fd0307db82", # langsmith run session id
|
||||
"tags": ["model1", "prod-2"], # langsmith run tags
|
||||
"metadata": { # langsmith run metadata
|
||||
"key1": "value1"
|
||||
},
|
||||
"dotted_order": "20240429T004912090000Z497f6eca-6276-4993-bfeb-53cbbbba6f08"
|
||||
}
|
||||
)
|
||||
print(response)
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
import Image from '@theme/IdealImage';
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# [BETA] Vertex AI Endpoints (Pass-Through)
|
||||
|
||||
Pass-through endpoints for Vertex AI - call provider-specific endpoint, in native format (no translation).
|
||||
Use the Vertex AI SDK to call endpoints on the LiteLLM Gateway (native provider format)
|
||||
|
||||
:::tip
|
||||
|
||||
|
@ -40,16 +44,119 @@ litellm --config /path/to/config.yaml
|
|||
|
||||
#### 3. Test it
|
||||
|
||||
```shell
|
||||
curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:countTokens \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{"instances":[{"content": "gm"}]}'
|
||||
```python
|
||||
import vertexai
|
||||
from google.auth.credentials import Credentials
|
||||
from vertexai.generative_models import GenerativeModel
|
||||
|
||||
LITELLM_PROXY_API_KEY = "sk-1234"
|
||||
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
|
||||
|
||||
import datetime
|
||||
|
||||
|
||||
class CredentialsWrapper(Credentials):
|
||||
def __init__(self, token=None):
|
||||
super().__init__()
|
||||
self.token = token
|
||||
self.expiry = None # or set to a future date if needed
|
||||
|
||||
def refresh(self, request):
|
||||
pass
|
||||
|
||||
def apply(self, headers, token=None):
|
||||
headers["Authorization"] = f"Bearer {self.token}"
|
||||
|
||||
@property
|
||||
def expired(self):
|
||||
return False # Always consider the token as non-expired
|
||||
|
||||
@property
|
||||
def valid(self):
|
||||
return True # Always consider the credentials as valid
|
||||
|
||||
|
||||
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
|
||||
|
||||
vertexai.init(
|
||||
project="adroit-crow-413218",
|
||||
location="us-central1",
|
||||
api_endpoint=LITELLM_PROXY_BASE,
|
||||
credentials=credentials,
|
||||
api_transport="rest",
|
||||
)
|
||||
|
||||
model = GenerativeModel("gemini-1.5-flash-001")
|
||||
|
||||
response = model.generate_content(
|
||||
"What's a good name for a flower shop that specializes in selling bouquets of dried flowers?"
|
||||
)
|
||||
|
||||
print(response.text)
|
||||
```
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Gemini API (Generate Content)
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="py" label="Vertex Python SDK">
|
||||
|
||||
```python
|
||||
import vertexai
|
||||
from google.auth.credentials import Credentials
|
||||
from vertexai.generative_models import GenerativeModel
|
||||
|
||||
LITELLM_PROXY_API_KEY = "sk-1234"
|
||||
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
|
||||
|
||||
import datetime
|
||||
|
||||
|
||||
class CredentialsWrapper(Credentials):
|
||||
def __init__(self, token=None):
|
||||
super().__init__()
|
||||
self.token = token
|
||||
self.expiry = None # or set to a future date if needed
|
||||
|
||||
def refresh(self, request):
|
||||
pass
|
||||
|
||||
def apply(self, headers, token=None):
|
||||
headers["Authorization"] = f"Bearer {self.token}"
|
||||
|
||||
@property
|
||||
def expired(self):
|
||||
return False # Always consider the token as non-expired
|
||||
|
||||
@property
|
||||
def valid(self):
|
||||
return True # Always consider the credentials as valid
|
||||
|
||||
|
||||
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
|
||||
|
||||
vertexai.init(
|
||||
project="adroit-crow-413218",
|
||||
location="us-central1",
|
||||
api_endpoint=LITELLM_PROXY_BASE,
|
||||
credentials=credentials,
|
||||
api_transport="rest",
|
||||
|
||||
)
|
||||
|
||||
model = GenerativeModel("gemini-1.5-flash-001")
|
||||
|
||||
response = model.generate_content(
|
||||
"What's a good name for a flower shop that specializes in selling bouquets of dried flowers?"
|
||||
)
|
||||
|
||||
print(response.text)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="Curl" label="Curl">
|
||||
|
||||
```shell
|
||||
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent \
|
||||
-H "Content-Type: application/json" \
|
||||
|
@ -57,8 +164,77 @@ curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-0
|
|||
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
### Embeddings API
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="py" label="Vertex Python SDK">
|
||||
|
||||
```python
|
||||
from typing import List, Optional
|
||||
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
|
||||
import vertexai
|
||||
from google.auth.credentials import Credentials
|
||||
from vertexai.generative_models import GenerativeModel
|
||||
|
||||
LITELLM_PROXY_API_KEY = "sk-1234"
|
||||
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
|
||||
|
||||
import datetime
|
||||
|
||||
|
||||
class CredentialsWrapper(Credentials):
|
||||
def __init__(self, token=None):
|
||||
super().__init__()
|
||||
self.token = token
|
||||
self.expiry = None # or set to a future date if needed
|
||||
|
||||
def refresh(self, request):
|
||||
pass
|
||||
|
||||
def apply(self, headers, token=None):
|
||||
headers["Authorization"] = f"Bearer {self.token}"
|
||||
|
||||
@property
|
||||
def expired(self):
|
||||
return False # Always consider the token as non-expired
|
||||
|
||||
@property
|
||||
def valid(self):
|
||||
return True # Always consider the credentials as valid
|
||||
|
||||
|
||||
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
|
||||
|
||||
vertexai.init(
|
||||
project="adroit-crow-413218",
|
||||
location="us-central1",
|
||||
api_endpoint=LITELLM_PROXY_BASE,
|
||||
credentials=credentials,
|
||||
api_transport="rest",
|
||||
)
|
||||
|
||||
|
||||
def embed_text(
|
||||
texts: List[str] = ["banana muffins? ", "banana bread? banana muffins?"],
|
||||
task: str = "RETRIEVAL_DOCUMENT",
|
||||
model_name: str = "text-embedding-004",
|
||||
dimensionality: Optional[int] = 256,
|
||||
) -> List[List[float]]:
|
||||
"""Embeds texts with a pre-trained, foundational model."""
|
||||
model = TextEmbeddingModel.from_pretrained(model_name)
|
||||
inputs = [TextEmbeddingInput(text, task) for text in texts]
|
||||
kwargs = dict(output_dimensionality=dimensionality) if dimensionality else {}
|
||||
embeddings = model.get_embeddings(inputs, **kwargs)
|
||||
return [embedding.values for embedding in embeddings]
|
||||
```
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="curl" label="Curl">
|
||||
|
||||
```shell
|
||||
curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:predict \
|
||||
-H "Content-Type: application/json" \
|
||||
|
@ -66,8 +242,86 @@ curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-geck
|
|||
-d '{"instances":[{"content": "gm"}]}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
### Imagen API
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="py" label="Vertex Python SDK">
|
||||
|
||||
```python
|
||||
from typing import List, Optional
|
||||
from vertexai.preview.vision_models import ImageGenerationModel
|
||||
import vertexai
|
||||
from google.auth.credentials import Credentials
|
||||
|
||||
LITELLM_PROXY_API_KEY = "sk-1234"
|
||||
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
|
||||
|
||||
import datetime
|
||||
|
||||
|
||||
class CredentialsWrapper(Credentials):
|
||||
def __init__(self, token=None):
|
||||
super().__init__()
|
||||
self.token = token
|
||||
self.expiry = None # or set to a future date if needed
|
||||
|
||||
def refresh(self, request):
|
||||
pass
|
||||
|
||||
def apply(self, headers, token=None):
|
||||
headers["Authorization"] = f"Bearer {self.token}"
|
||||
|
||||
@property
|
||||
def expired(self):
|
||||
return False # Always consider the token as non-expired
|
||||
|
||||
@property
|
||||
def valid(self):
|
||||
return True # Always consider the credentials as valid
|
||||
|
||||
|
||||
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
|
||||
|
||||
vertexai.init(
|
||||
project="adroit-crow-413218",
|
||||
location="us-central1",
|
||||
api_endpoint=LITELLM_PROXY_BASE,
|
||||
credentials=credentials,
|
||||
api_transport="rest",
|
||||
)
|
||||
|
||||
model = ImageGenerationModel.from_pretrained("imagen-3.0-generate-001")
|
||||
|
||||
images = model.generate_images(
|
||||
prompt="An otter playing in a river",  # example value; the original snippet left `prompt` undefined
|
||||
# Optional parameters
|
||||
number_of_images=1,
|
||||
language="en",
|
||||
# You can't use a seed value and watermark at the same time.
|
||||
# add_watermark=False,
|
||||
# seed=100,
|
||||
aspect_ratio="1:1",
|
||||
safety_filter_level="block_some",
|
||||
person_generation="allow_adult",
|
||||
)
|
||||
|
||||
images[0].save(location="generated_image.png", include_generation_parameters=False)  # example path; the original snippet left `output_file` undefined
|
||||
|
||||
# Optional. View the generated image in a notebook.
|
||||
# images[0].show()
|
||||
|
||||
print(f"Created output image using {len(images[0]._image_bytes)} bytes")
|
||||
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="curl" label="Curl">
|
||||
|
||||
```shell
|
||||
curl http://localhost:4000/vertex-ai/publishers/google/models/imagen-3.0-generate-001:predict \
|
||||
-H "Content-Type: application/json" \
|
||||
|
@ -75,8 +329,86 @@ curl http://localhost:4000/vertex-ai/publishers/google/models/imagen-3.0-generat
|
|||
-d '{"instances":[{"prompt": "make an otter"}], "parameters": {"sampleCount": 1}}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
### Count Tokens API
|
||||
|
||||
|
||||
<Tabs>
|
||||
|
||||
<TabItem value="py" label="Vertex Python SDK">
|
||||
|
||||
```python
|
||||
from typing import List, Optional
|
||||
from vertexai.generative_models import GenerativeModel
|
||||
import vertexai
|
||||
from google.auth.credentials import Credentials
|
||||
|
||||
LITELLM_PROXY_API_KEY = "sk-1234"
|
||||
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
|
||||
|
||||
import datetime
|
||||
|
||||
|
||||
class CredentialsWrapper(Credentials):
|
||||
def __init__(self, token=None):
|
||||
super().__init__()
|
||||
self.token = token
|
||||
self.expiry = None # or set to a future date if needed
|
||||
|
||||
def refresh(self, request):
|
||||
pass
|
||||
|
||||
def apply(self, headers, token=None):
|
||||
headers["Authorization"] = f"Bearer {self.token}"
|
||||
|
||||
@property
|
||||
def expired(self):
|
||||
return False # Always consider the token as non-expired
|
||||
|
||||
@property
|
||||
def valid(self):
|
||||
return True # Always consider the credentials as valid
|
||||
|
||||
|
||||
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
|
||||
|
||||
vertexai.init(
|
||||
project="adroit-crow-413218",
|
||||
location="us-central1",
|
||||
api_endpoint=LITELLM_PROXY_BASE,
|
||||
credentials=credentials,
|
||||
api_transport="rest",
|
||||
)
|
||||
|
||||
|
||||
model = GenerativeModel("gemini-1.5-flash-001")
|
||||
|
||||
prompt = "Why is the sky blue?"
|
||||
|
||||
# Prompt tokens count
|
||||
response = model.count_tokens(prompt)
|
||||
print(f"Prompt Token Count: {response.total_tokens}")
|
||||
print(f"Prompt Character Count: {response.total_billable_characters}")
|
||||
|
||||
# Send text to Gemini
|
||||
response = model.generate_content(prompt)
|
||||
|
||||
# Response tokens count
|
||||
usage_metadata = response.usage_metadata
|
||||
print(f"Prompt Token Count: {usage_metadata.prompt_token_count}")
|
||||
print(f"Candidates Token Count: {usage_metadata.candidates_token_count}")
|
||||
print(f"Total Token Count: {usage_metadata.total_token_count}")
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="curl" label="Curl">
|
||||
|
||||
|
||||
|
||||
```shell
|
||||
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
|
||||
-H "Content-Type: application/json" \
|
||||
|
@ -84,10 +416,83 @@ curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-0
|
|||
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
### Tuning API
|
||||
|
||||
Create Fine Tuning Job
|
||||
|
||||
<Tabs>
|
||||
|
||||
<TabItem value="py" label="Vertex Python SDK">
|
||||
|
||||
```python
|
||||
from typing import List, Optional
|
||||
from vertexai.preview.tuning import sft
|
||||
import vertexai
|
||||
from google.auth.credentials import Credentials
|
||||
|
||||
LITELLM_PROXY_API_KEY = "sk-1234"
|
||||
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
|
||||
|
||||
import datetime
|
||||
|
||||
|
||||
class CredentialsWrapper(Credentials):
|
||||
def __init__(self, token=None):
|
||||
super().__init__()
|
||||
self.token = token
|
||||
self.expiry = None # or set to a future date if needed
|
||||
|
||||
def refresh(self, request):
|
||||
pass
|
||||
|
||||
def apply(self, headers, token=None):
|
||||
headers["Authorization"] = f"Bearer {self.token}"
|
||||
|
||||
@property
|
||||
def expired(self):
|
||||
return False # Always consider the token as non-expired
|
||||
|
||||
@property
|
||||
def valid(self):
|
||||
return True # Always consider the credentials as valid
|
||||
|
||||
|
||||
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
|
||||
|
||||
vertexai.init(
|
||||
project="adroit-crow-413218",
|
||||
location="us-central1",
|
||||
api_endpoint=LITELLM_PROXY_BASE,
|
||||
credentials=credentials,
|
||||
api_transport="rest",
|
||||
)
|
||||
|
||||
|
||||
# TODO(developer): Update project
|
||||
vertexai.init(project=PROJECT_ID, location="us-central1")
|
||||
|
||||
sft_tuning_job = sft.train(
|
||||
source_model="gemini-1.0-pro-002",
|
||||
train_dataset="gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl",
|
||||
)
|
||||
|
||||
# Polling for job completion
|
||||
while not sft_tuning_job.has_ended:
|
||||
time.sleep(60)
|
||||
sft_tuning_job.refresh()
|
||||
|
||||
print(sft_tuning_job.tuned_model_name)
|
||||
print(sft_tuning_job.tuned_model_endpoint_name)
|
||||
print(sft_tuning_job.experiment)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="curl" label="Curl">
|
||||
|
||||
```shell
|
||||
curl http://localhost:4000/vertex-ai/tuningJobs \
|
||||
-H "Content-Type: application/json" \
|
||||
|
@ -99,3 +504,7 @@ curl http://localhost:4000/vertex-ai/tuningJobs \
|
|||
}
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
|
|
@ -131,6 +131,56 @@ Expected Response
|
|||
}
|
||||
```
|
||||
|
||||
## Add Streaming Support
|
||||
|
||||
Here's a simple example of returning unix epoch seconds for both completion + streaming use-cases.
|
||||
|
||||
s/o [@Eloy Lafuente](https://github.com/stronk7) for this code example.
|
||||
|
||||
```python
|
||||
import time
|
||||
from typing import Iterator, AsyncIterator
|
||||
from litellm.types.utils import GenericStreamingChunk, ModelResponse
|
||||
from litellm import CustomLLM, completion, acompletion
|
||||
|
||||
class UnixTimeLLM(CustomLLM):
|
||||
def completion(self, *args, **kwargs) -> ModelResponse:
|
||||
return completion(
|
||||
model="test/unixtime",
|
||||
mock_response=str(int(time.time())),
|
||||
) # type: ignore
|
||||
|
||||
async def acompletion(self, *args, **kwargs) -> ModelResponse:
|
||||
return await acompletion(
|
||||
model="test/unixtime",
|
||||
mock_response=str(int(time.time())),
|
||||
) # type: ignore
|
||||
|
||||
def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
|
||||
generic_streaming_chunk: GenericStreamingChunk = {
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"is_finished": True,
|
||||
"text": str(int(time.time())),
|
||||
"tool_use": None,
|
||||
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
|
||||
}
|
||||
return generic_streaming_chunk # type: ignore
|
||||
|
||||
async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
|
||||
generic_streaming_chunk: GenericStreamingChunk = {
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"is_finished": True,
|
||||
"text": str(int(time.time())),
|
||||
"tool_use": None,
|
||||
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
|
||||
}
|
||||
yield generic_streaming_chunk # type: ignore
|
||||
|
||||
unixtime = UnixTimeLLM()
|
||||
```
|
||||
|
||||
## Custom Handler Spec
|
||||
|
||||
```python
|
||||
|
|
|
@ -661,6 +661,7 @@ Here's how to use Vertex AI with the LiteLLM Proxy Server
|
|||
## Specifying Safety Settings
|
||||
In certain use-cases you may need to make calls to the models and pass [safety settings](https://ai.google.dev/docs/safety_setting_gemini) different from the defaults. To do so, simply pass the `safety_settings` argument to `completion` or `acompletion`. For example:
|
||||
|
||||
### Set per model/request
|
||||
|
||||
<Tabs>
|
||||
|
||||
|
@ -752,6 +753,65 @@ response = client.chat.completions.create(
|
|||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
### Set Globally
|
||||
|
||||
<Tabs>
|
||||
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
litellm.set_verbose = True  # 👈 See RAW REQUEST/RESPONSE
|
||||
|
||||
litellm.vertex_ai_safety_settings = [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"threshold": "BLOCK_NONE",
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"threshold": "BLOCK_NONE",
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"threshold": "BLOCK_NONE",
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"threshold": "BLOCK_NONE",
|
||||
},
|
||||
]
|
||||
response = litellm.completion(
|
||||
model="vertex_ai/gemini-pro",
|
||||
messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]
|
||||
)
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="Proxy">
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gemini-experimental
|
||||
litellm_params:
|
||||
model: vertex_ai/gemini-experimental
|
||||
vertex_project: litellm-epic
|
||||
vertex_location: us-central1
|
||||
|
||||
litellm_settings:
|
||||
vertex_ai_safety_settings:
|
||||
- category: HARM_CATEGORY_HARASSMENT
|
||||
threshold: BLOCK_NONE
|
||||
- category: HARM_CATEGORY_HATE_SPEECH
|
||||
threshold: BLOCK_NONE
|
||||
- category: HARM_CATEGORY_SEXUALLY_EXPLICIT
|
||||
threshold: BLOCK_NONE
|
||||
- category: HARM_CATEGORY_DANGEROUS_CONTENT
|
||||
threshold: BLOCK_NONE
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## Set Vertex Project & Vertex Location
|
||||
All calls using Vertex AI require the following parameters:
|
||||
* Your Project ID
|
||||
|
@ -1450,7 +1510,7 @@ curl http://0.0.0.0:4000/v1/chat/completions \
|
|||
| code-gecko@latest| `completion('code-gecko@latest', messages)` |
|
||||
|
||||
|
||||
## Embedding Models
|
||||
## **Embedding Models**
|
||||
|
||||
#### Usage - Embedding
|
||||
```python
|
||||
|
@ -1504,7 +1564,158 @@ response = litellm.embedding(
|
|||
)
|
||||
```
|
||||
|
||||
## Image Generation Models
|
||||
## **Multi-Modal Embeddings**
|
||||
|
||||
Usage
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
response = await litellm.aembedding(
|
||||
model="vertex_ai/multimodalembedding@001",
|
||||
input=[
|
||||
{
|
||||
"image": {
|
||||
"gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
|
||||
},
|
||||
"text": "this is a unicorn",
|
||||
},
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="LiteLLM PROXY (Unified Endpoint)">
|
||||
|
||||
1. Add model to config.yaml
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: multimodalembedding@001
|
||||
litellm_params:
|
||||
model: vertex_ai/multimodalembedding@001
|
||||
vertex_project: "adroit-crow-413218"
|
||||
vertex_location: "us-central1"
|
||||
vertex_credentials: adroit-crow-413218-a956eef1a2a8.json
|
||||
|
||||
litellm_settings:
|
||||
drop_params: True
|
||||
```
|
||||
|
||||
2. Start Proxy
|
||||
|
||||
```
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Make a request using the OpenAI Python SDK
|
||||
|
||||
```python
|
||||
import openai
|
||||
|
||||
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
|
||||
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.embeddings.create(
|
||||
model="multimodalembedding@001",
|
||||
input = None,
|
||||
extra_body = {
|
||||
"instances": [
|
||||
{
|
||||
"image": {
|
||||
"gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
|
||||
},
|
||||
"text": "this is a unicorn",
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy-vtx" label="LiteLLM PROXY (Vertex SDK)">
|
||||
|
||||
1. Add model to config.yaml
|
||||
```yaml
|
||||
default_vertex_config:
|
||||
vertex_project: "adroit-crow-413218"
|
||||
vertex_location: "us-central1"
|
||||
vertex_credentials: adroit-crow-413218-a956eef1a2a8.json
|
||||
```
|
||||
|
||||
2. Start Proxy
|
||||
|
||||
```
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Make a request using the Vertex Python SDK
|
||||
|
||||
```python
|
||||
import vertexai
|
||||
|
||||
from vertexai.vision_models import Image, MultiModalEmbeddingModel, Video
|
||||
from vertexai.vision_models import VideoSegmentConfig
|
||||
from google.auth.credentials import Credentials
|
||||
|
||||
|
||||
LITELLM_PROXY_API_KEY = "sk-1234"
|
||||
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
|
||||
|
||||
import datetime
|
||||
|
||||
class CredentialsWrapper(Credentials):
|
||||
def __init__(self, token=None):
|
||||
super().__init__()
|
||||
self.token = token
|
||||
self.expiry = None # or set to a future date if needed
|
||||
|
||||
def refresh(self, request):
|
||||
pass
|
||||
|
||||
def apply(self, headers, token=None):
|
||||
headers['Authorization'] = f'Bearer {self.token}'
|
||||
|
||||
@property
|
||||
def expired(self):
|
||||
return False # Always consider the token as non-expired
|
||||
|
||||
@property
|
||||
def valid(self):
|
||||
return True # Always consider the credentials as valid
|
||||
|
||||
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
|
||||
|
||||
vertexai.init(
|
||||
project="adroit-crow-413218",
|
||||
location="us-central1",
|
||||
api_endpoint=LITELLM_PROXY_BASE,
|
||||
credentials = credentials,
|
||||
api_transport="rest",
|
||||
|
||||
)
|
||||
|
||||
model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding")
|
||||
image = Image.load_from_file(
|
||||
"gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
|
||||
)
|
||||
|
||||
embeddings = model.get_embeddings(
|
||||
image=image,
|
||||
contextual_text="Colosseum",
|
||||
dimension=1408,
|
||||
)
|
||||
print(f"Image Embedding: {embeddings.image_embedding}")
|
||||
print(f"Text Embedding: {embeddings.text_embedding}")
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
## **Image Generation Models**
|
||||
|
||||
Usage
|
||||
|
||||
|
|
|
@ -728,6 +728,7 @@ general_settings:
|
|||
"disable_spend_logs": "boolean", # turn off writing each transaction to the db
|
||||
"disable_master_key_return": "boolean", # turn off returning master key on UI (checked on '/user/info' endpoint)
|
||||
"disable_reset_budget": "boolean", # turn off reset budget scheduled task
|
||||
"disable_adding_master_key_hash_to_db": "boolean", # turn off storing master key hash in db, for spend tracking
|
||||
"enable_jwt_auth": "boolean", # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims
|
||||
"enforce_user_param": "boolean", # requires all openai endpoint requests to have a 'user' param
|
||||
"allowed_routes": "list", # list of allowed proxy API routes - a user can access. (currently JWT-Auth only)
|
||||
|
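As a minimal sketch (assuming the standard proxy `config.yaml` layout used elsewhere in these docs), the newly documented flag is set like this:

```yaml
general_settings:
  master_key: sk-1234
  disable_adding_master_key_hash_to_db: true   # don't store the master key hash in the db for spend tracking
  disable_spend_logs: false                    # keep writing per-request spend logs
```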
|
|
@ -101,8 +101,38 @@ Requirements:
|
|||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="key" label="Set on Key">
|
||||
|
||||
```bash
|
||||
curl -L -X POST 'http://0.0.0.0:4000/key/generate' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"metadata": {
|
||||
"tags": ["tag1", "tag2", "tag3"]
|
||||
}
|
||||
}
|
||||
|
||||
'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="team" label="Set on Team">
|
||||
|
||||
```bash
|
||||
curl -L -X POST 'http://0.0.0.0:4000/team/new' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"metadata": {
|
||||
"tags": ["tag1", "tag2", "tag3"]
|
||||
}
|
||||
}
|
||||
|
||||
'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||
|
||||
Set `extra_body={"metadata": { }}` to the `metadata` you want to pass
|
||||
|
@ -270,7 +300,42 @@ Requirements:
|
|||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="key" label="Set on Key">
|
||||
|
||||
```bash
|
||||
curl -L -X POST 'http://0.0.0.0:4000/key/generate' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"metadata": {
|
||||
"spend_logs_metadata": {
|
||||
"hello": "world"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="team" label="Set on Team">
|
||||
|
||||
```bash
|
||||
curl -L -X POST 'http://0.0.0.0:4000/team/new' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"metadata": {
|
||||
"spend_logs_metadata": {
|
||||
"hello": "world"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||
|
||||
|
|
|
@ -61,6 +61,51 @@ litellm_settings:
|
|||
|
||||
Removes any field with `user_api_key_*` from metadata.
|
||||
|
||||
## What gets logged?
|
||||
|
||||
Found under `kwargs["standard_logging_payload"]`. This is a standard payload, logged for every response.
|
||||
|
||||
```python
|
||||
class StandardLoggingPayload(TypedDict):
|
||||
id: str
|
||||
call_type: str
|
||||
response_cost: float
|
||||
total_tokens: int
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
startTime: float
|
||||
endTime: float
|
||||
completionStartTime: float
|
||||
model_map_information: StandardLoggingModelInformation
|
||||
model: str
|
||||
model_id: Optional[str]
|
||||
model_group: Optional[str]
|
||||
api_base: str
|
||||
metadata: StandardLoggingMetadata
|
||||
cache_hit: Optional[bool]
|
||||
cache_key: Optional[str]
|
||||
saved_cache_cost: Optional[float]
|
||||
request_tags: list
|
||||
end_user: Optional[str]
|
||||
requester_ip_address: Optional[str]
|
||||
messages: Optional[Union[str, list, dict]]
|
||||
response: Optional[Union[str, list, dict]]
|
||||
model_parameters: dict
|
||||
hidden_params: StandardLoggingHiddenParams
|
||||
|
||||
class StandardLoggingHiddenParams(TypedDict):
|
||||
model_id: Optional[str]
|
||||
cache_key: Optional[str]
|
||||
api_base: Optional[str]
|
||||
response_cost: Optional[str]
|
||||
additional_headers: Optional[dict]
|
||||
|
||||
|
||||
class StandardLoggingModelInformation(TypedDict):
|
||||
model_map_key: str
|
||||
model_map_value: Optional[ModelInfo]
|
||||
```
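A custom callback can read this payload directly. A minimal sketch, assuming the `CustomLogger` interface documented in the custom-callback docs:

```python
import litellm
from litellm.integrations.custom_logger import CustomLogger

class PayloadLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # every response carries the standard payload described above
        payload = kwargs.get("standard_logging_payload") or {}
        print(payload.get("model"), payload.get("response_cost"), payload.get("cache_hit"))

litellm.callbacks = [PayloadLogger()]
```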
|
||||
|
||||
## Logging Proxy Input/Output - Langfuse
|
||||
|
||||
We will use the `--config` to set `litellm.success_callback = ["langfuse"]`. This will log all successful LLM calls to Langfuse. Make sure to set `LANGFUSE_PUBLIC_KEY` and `LANGFUSE_SECRET_KEY` in your environment.
|
||||
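A minimal proxy config sketch for this (key values are placeholders):

```yaml
litellm_settings:
  success_callback: ["langfuse"]

environment_variables:            # or export these before starting the proxy
  LANGFUSE_PUBLIC_KEY: "pk-lf-..."
  LANGFUSE_SECRET_KEY: "sk-lf-..."
```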
|
|
|
@ -334,3 +334,4 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
|||
```
|
||||
Key=... over available RPM=0. Model RPM=100, Active keys=None
|
||||
```
|
||||
|
||||
|
|
|
@ -488,9 +488,34 @@ You can set:
|
|||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="per-team" label="Per Team">
|
||||
|
||||
Use `/team/new` or `/team/update` to persist rate limits across multiple keys for a team.
|
||||
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/team/new' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{"team_id": "my-prod-team", "max_parallel_requests": 10, "tpm_limit": 20, "rpm_limit": 4}'
|
||||
```
|
||||
|
||||
[**See Swagger**](https://litellm-api.up.railway.app/#/team%20management/new_team_team_new_post)
|
||||
|
||||
**Expected Response**
|
||||
|
||||
```json
|
||||
{
|
||||
"key": "sk-sA7VDkyhlQ7m8Gt77Mbt3Q",
|
||||
"expires": "2024-01-19T01:21:12.816168",
|
||||
"team_id": "my-prod-team",
|
||||
}
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="per-user" label="Per Internal User">
|
||||
|
||||
Use `/user/new`, to persist rate limits across multiple keys.
|
||||
Use `/user/new` or `/user/update` to persist rate limits across multiple keys for internal users.
|
||||
|
||||
|
||||
```shell
|
||||
|
@ -653,6 +678,70 @@ curl --location 'http://localhost:4000/chat/completions' \
|
|||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## Set default budget for ALL internal users
|
||||
|
||||
Use this to set a default budget for users who you give keys to.
|
||||
|
||||
This will apply when a user has [`user_role="internal_user"`](./self_serve.md#available-roles) (set this via `/user/new` or `/user/update`).
|
||||
|
||||
This will NOT apply if a key has a team_id (team budgets will apply then). [Tell us how we can improve this!](https://github.com/BerriAI/litellm/issues)
|
||||
|
||||
1. Define max budget in your config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: "gpt-3.5-turbo"
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
litellm_settings:
|
||||
max_internal_user_budget: 0 # amount in USD
|
||||
internal_user_budget_duration: "1mo" # reset every month
|
||||
```
|
||||
|
||||
2. Create key for user
|
||||
|
||||
```bash
|
||||
curl -L -X POST 'http://0.0.0.0:4000/key/generate' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{}'
|
||||
```
|
||||
|
||||
Expected Response:
|
||||
|
||||
```bash
|
||||
{
|
||||
...
|
||||
"key": "sk-X53RdxnDhzamRwjKXR4IHg"
|
||||
}
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
```bash
|
||||
curl -L -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer sk-X53RdxnDhzamRwjKXR4IHg' \
|
||||
-d '{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Hey, how's it going?"}]
|
||||
}'
|
||||
```
|
||||
|
||||
Expected Response:
|
||||
|
||||
```bash
|
||||
{
|
||||
"error": {
|
||||
"message": "ExceededBudget: User=<user_id> over budget. Spend=3.7e-05, Budget=0.0",
|
||||
"type": "budget_exceeded",
|
||||
"param": null,
|
||||
"code": "400"
|
||||
}
|
||||
}
|
||||
```
|
||||
## Grant Access to new model
|
||||
|
||||
Use model access groups to give users access to select models, and add new ones to it over time (e.g. mistral, llama-2, etc.).
|
||||
|
|
8 litellm-js/spend-logs/package-lock.json (generated)
@ -6,7 +6,7 @@
|
|||
"": {
|
||||
"dependencies": {
|
||||
"@hono/node-server": "^1.10.1",
|
||||
"hono": "^4.2.7"
|
||||
"hono": "^4.5.8"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20.11.17",
|
||||
|
@ -463,9 +463,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/hono": {
|
||||
"version": "4.2.7",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.2.7.tgz",
|
||||
"integrity": "sha512-k1xHi86tJnRIVvqhFMBDGFKJ8r5O+bEsT4P59ZK59r0F300Xd910/r237inVfuT/VmE86RQQffX4OYNda6dLXw==",
|
||||
"version": "4.5.8",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.5.8.tgz",
|
||||
"integrity": "sha512-pqpSlcdqGkpTTRpLYU1PnCz52gVr0zVR9H5GzMyJWuKQLLEBQxh96q45QizJ2PPX8NATtz2mu31/PKW/Jt+90Q==",
|
||||
"engines": {
|
||||
"node": ">=16.0.0"
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
},
|
||||
"dependencies": {
|
||||
"@hono/node-server": "^1.10.1",
|
||||
"hono": "^4.2.7"
|
||||
"hono": "^4.5.8"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20.11.17",
|
||||
|
|
|
@ -339,6 +339,7 @@ api_version = None
|
|||
organization = None
|
||||
project = None
|
||||
config_path = None
|
||||
vertex_ai_safety_settings: Optional[dict] = None
|
||||
####### COMPLETION MODELS ###################
|
||||
open_ai_chat_completion_models: List = []
|
||||
open_ai_text_completion_models: List = []
|
||||
|
|
|
@ -98,6 +98,10 @@ class LangsmithLogger(CustomLogger):
|
|||
project_name = metadata.get("project_name", self.langsmith_project)
|
||||
run_name = metadata.get("run_name", self.langsmith_default_run_name)
|
||||
run_id = metadata.get("id", None)
|
||||
parent_run_id = metadata.get("parent_run_id", None)
|
||||
trace_id = metadata.get("trace_id", None)
|
||||
session_id = metadata.get("session_id", None)
|
||||
dotted_order = metadata.get("dotted_order", None)
|
||||
tags = metadata.get("tags", []) or []
|
||||
verbose_logger.debug(
|
||||
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
|
||||
|
@ -149,6 +153,18 @@ class LangsmithLogger(CustomLogger):
|
|||
if run_id:
|
||||
data["id"] = run_id
|
||||
|
||||
if parent_run_id:
|
||||
data["parent_run_id"] = parent_run_id
|
||||
|
||||
if trace_id:
|
||||
data["trace_id"] = trace_id
|
||||
|
||||
if session_id:
|
||||
data["session_id"] = session_id
|
||||
|
||||
if dotted_order:
|
||||
data["dotted_order"] = dotted_order
|
||||
|
||||
verbose_logger.debug("Langsmith Logging data on langsmith: %s", data)
|
||||
|
||||
return data
|
||||
|
|
|
@ -524,6 +524,7 @@ class Logging:
|
|||
TextCompletionResponse,
|
||||
HttpxBinaryResponseContent,
|
||||
],
|
||||
cache_hit: Optional[bool] = None,
|
||||
):
|
||||
"""
|
||||
Calculate response cost using result + logging object variables.
|
||||
|
@ -535,10 +536,13 @@ class Logging:
|
|||
litellm_params=self.litellm_params
|
||||
)
|
||||
|
||||
if cache_hit is None:
|
||||
cache_hit = self.model_call_details.get("cache_hit", False)
|
||||
|
||||
response_cost = litellm.response_cost_calculator(
|
||||
response_object=result,
|
||||
model=self.model,
|
||||
cache_hit=self.model_call_details.get("cache_hit", False),
|
||||
cache_hit=cache_hit,
|
||||
custom_llm_provider=self.model_call_details.get(
|
||||
"custom_llm_provider", None
|
||||
),
|
||||
|
@ -630,6 +634,7 @@ class Logging:
|
|||
init_response_obj=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
)
|
||||
)
|
||||
return start_time, end_time, result
|
||||
|
@ -2181,6 +2186,7 @@ def get_standard_logging_object_payload(
|
|||
init_response_obj: Any,
|
||||
start_time: dt_object,
|
||||
end_time: dt_object,
|
||||
logging_obj: Logging,
|
||||
) -> Optional[StandardLoggingPayload]:
|
||||
try:
|
||||
if kwargs is None:
|
||||
|
@ -2277,11 +2283,17 @@ def get_standard_logging_object_payload(
|
|||
cache_key = litellm.cache.get_cache_key(**kwargs)
|
||||
else:
|
||||
cache_key = None
|
||||
|
||||
saved_cache_cost: Optional[float] = None
|
||||
if cache_hit is True:
|
||||
import time
|
||||
|
||||
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
|
||||
|
||||
saved_cache_cost = logging_obj._response_cost_calculator(
|
||||
result=init_response_obj, cache_hit=False
|
||||
)
|
||||
|
||||
## Get model cost information ##
|
||||
base_model = _get_base_model_from_metadata(model_call_details=kwargs)
|
||||
custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params)
|
||||
|
@ -2318,6 +2330,7 @@ def get_standard_logging_object_payload(
|
|||
id=str(id),
|
||||
call_type=call_type or "",
|
||||
cache_hit=cache_hit,
|
||||
saved_cache_cost=saved_cache_cost,
|
||||
startTime=start_time_float,
|
||||
endTime=end_time_float,
|
||||
completionStartTime=completion_start_time_float,
|
||||
|
|
|
@ -84,6 +84,7 @@ BEDROCK_CONVERSE_MODELS = [
|
|||
"meta.llama3-1-8b-instruct-v1:0",
|
||||
"meta.llama3-1-70b-instruct-v1:0",
|
||||
"meta.llama3-1-405b-instruct-v1:0",
|
||||
"meta.llama3-70b-instruct-v1:0",
|
||||
"mistral.mistral-large-2407-v1:0",
|
||||
]
|
||||
|
||||
|
@ -1480,7 +1481,7 @@ class BedrockConverseLLM(BaseAWSLLM):
|
|||
optional_params: dict,
|
||||
acompletion: bool,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
litellm_params=None,
|
||||
litellm_params: dict,
|
||||
logger_fn=None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||
|
@ -1596,6 +1597,14 @@ class BedrockConverseLLM(BaseAWSLLM):
|
|||
supported_tool_call_params = ["tools", "tool_choice"]
|
||||
supported_guardrail_params = ["guardrailConfig"]
|
||||
## TRANSFORMATION ##
|
||||
|
||||
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
|
||||
messages=messages,
|
||||
model=model,
|
||||
llm_provider="bedrock_converse",
|
||||
user_continue_message=litellm_params.pop("user_continue_message", None),
|
||||
)
|
||||
|
||||
# send all model-specific params in 'additional_request_params'
|
||||
for k, v in inference_params.items():
|
||||
if (
|
||||
|
@ -1608,11 +1617,6 @@ class BedrockConverseLLM(BaseAWSLLM):
|
|||
for key in additional_request_keys:
|
||||
inference_params.pop(key, None)
|
||||
|
||||
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
|
||||
messages=messages,
|
||||
model=model,
|
||||
llm_provider="bedrock_converse",
|
||||
)
|
||||
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
|
||||
inference_params.pop("tools", [])
|
||||
)
|
||||
|
|
|
@ -124,12 +124,14 @@ class CohereConfig:
|
|||
}
|
||||
|
||||
|
||||
def validate_environment(api_key):
|
||||
headers = {
|
||||
def validate_environment(api_key, headers: dict):
|
||||
headers.update(
|
||||
{
|
||||
"Request-Source": "unspecified:litellm",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
)
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
|
@ -144,11 +146,12 @@ def completion(
|
|||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
headers: dict,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
headers = validate_environment(api_key)
|
||||
headers = validate_environment(api_key, headers=headers)
|
||||
completion_url = api_base
|
||||
model = model
|
||||
prompt = " ".join(message["content"] for message in messages)
|
||||
|
@ -338,13 +341,14 @@ def embedding(
|
|||
model_response: litellm.EmbeddingResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
optional_params: dict,
|
||||
headers: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
aembedding: Optional[bool] = None,
|
||||
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
):
|
||||
headers = validate_environment(api_key)
|
||||
headers = validate_environment(api_key, headers=headers)
|
||||
embed_url = "https://api.cohere.ai/v1/embed"
|
||||
model = model
|
||||
data = {"model": model, "texts": input, **optional_params}
|
||||
|
|
|
@ -116,12 +116,14 @@ class CohereChatConfig:
|
|||
}
|
||||
|
||||
|
||||
def validate_environment(api_key):
|
||||
headers = {
|
||||
def validate_environment(api_key, headers: dict):
|
||||
headers.update(
|
||||
{
|
||||
"Request-Source": "unspecified:litellm",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
)
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
|
@ -203,13 +205,14 @@ def completion(
|
|||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
optional_params: dict,
|
||||
headers: dict,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
headers = validate_environment(api_key)
|
||||
headers = validate_environment(api_key, headers=headers)
|
||||
completion_url = api_base
|
||||
model = model
|
||||
most_recent_message, chat_history = cohere_messages_pt_v2(
|
||||
|
|
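Together with the `litellm/main.py` changes later in this diff, caller-supplied headers now reach the Cohere API alongside the defaults set in `validate_environment`. A hedged usage sketch (the header name is illustrative):

```python
import litellm

response = litellm.completion(
    model="cohere/command-r",
    messages=[{"role": "user", "content": "hi"}],
    extra_headers={"X-Client-Name": "my-app"},  # merged into the default Cohere headers
)
```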
|
@ -4,14 +4,17 @@ import traceback
|
|||
import types
|
||||
import uuid
|
||||
from itertools import chain
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
import httpx
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
|
||||
from litellm.types.llms.openai import ChatCompletionAssistantToolCall
|
||||
|
||||
|
||||
class OllamaError(Exception):
|
||||
|
@ -175,7 +178,7 @@ class OllamaChatConfig:
|
|||
## CHECK IF MODEL SUPPORTS TOOL CALLING ##
|
||||
try:
|
||||
model_info = litellm.get_model_info(
|
||||
model=model, custom_llm_provider="ollama_chat"
|
||||
model=model, custom_llm_provider="ollama"
|
||||
)
|
||||
if model_info.get("supports_function_calling") is True:
|
||||
optional_params["tools"] = value
|
||||
|
@ -237,13 +240,30 @@ def get_ollama_response(
|
|||
function_name = optional_params.pop("function_name", None)
|
||||
tools = optional_params.pop("tools", None)
|
||||
|
||||
new_messages = []
|
||||
for m in messages:
|
||||
if "role" in m and m["role"] == "tool":
|
||||
m["role"] = "assistant"
|
||||
if isinstance(
|
||||
m, BaseModel
|
||||
): # avoid message serialization issues - https://github.com/BerriAI/litellm/issues/5319
|
||||
m = m.model_dump(exclude_none=True)
|
||||
if m.get("tool_calls") is not None and isinstance(m["tool_calls"], list):
|
||||
new_tools: List[OllamaToolCall] = []
|
||||
for tool in m["tool_calls"]:
|
||||
typed_tool = ChatCompletionAssistantToolCall(**tool) # type: ignore
|
||||
if typed_tool["type"] == "function":
|
||||
ollama_tool_call = OllamaToolCall(
|
||||
function=OllamaToolCallFunction(
|
||||
name=typed_tool["function"]["name"],
|
||||
arguments=json.loads(typed_tool["function"]["arguments"]),
|
||||
)
|
||||
)
|
||||
new_tools.append(ollama_tool_call)
|
||||
m["tool_calls"] = new_tools
|
||||
new_messages.append(m)
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"messages": new_messages,
|
||||
"options": optional_params,
|
||||
"stream": stream,
|
||||
}
|
||||
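The loop above normalizes OpenAI-style tool messages (and Pydantic message objects) before the request is sent to Ollama: `tool` turns are re-labelled as `assistant`, and `tool_calls` entries are converted to Ollama's structure with JSON-decoded arguments. A rough sketch of a conversation that exercises this path (local model name and tool call are illustrative):

```python
import litellm

messages = [
    {"role": "user", "content": "What's the weather in SF?"},
    {
        "role": "assistant",
        "tool_calls": [{
            "id": "call_1",
            "type": "function",
            "function": {"name": "get_weather", "arguments": "{\"city\": \"SF\"}"},
        }],
    },
    {"role": "tool", "tool_call_id": "call_1", "content": "65F and sunny"},
]

response = litellm.completion(model="ollama_chat/llama3.1", messages=messages)
```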
|
@ -263,7 +283,7 @@ def get_ollama_response(
|
|||
},
|
||||
)
|
||||
if acompletion is True:
|
||||
if stream == True:
|
||||
if stream is True:
|
||||
response = ollama_async_streaming(
|
||||
url=url,
|
||||
api_key=api_key,
|
||||
|
@ -283,7 +303,7 @@ def get_ollama_response(
|
|||
function_name=function_name,
|
||||
)
|
||||
return response
|
||||
elif stream == True:
|
||||
elif stream is True:
|
||||
return ollama_completion_stream(
|
||||
url=url, api_key=api_key, data=data, logging_obj=logging_obj
|
||||
)
|
||||
|
|
|
@ -84,6 +84,8 @@ class MistralConfig:
|
|||
|
||||
- `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'.
|
||||
|
||||
- `stop` (string or array of strings): Stop generation if this token is detected. Or if one of these tokens is detected when providing an array
|
||||
|
||||
- `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results.
|
||||
|
||||
- `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'.
|
||||
|
@ -99,6 +101,7 @@ class MistralConfig:
|
|||
random_seed: Optional[int] = None
|
||||
safe_prompt: Optional[bool] = None
|
||||
response_format: Optional[dict] = None
|
||||
stop: Optional[Union[str, list]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -110,6 +113,7 @@ class MistralConfig:
|
|||
random_seed: Optional[int] = None,
|
||||
safe_prompt: Optional[bool] = None,
|
||||
response_format: Optional[dict] = None,
|
||||
stop: Optional[Union[str, list]] = None
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
|
@ -143,6 +147,7 @@ class MistralConfig:
|
|||
"tools",
|
||||
"tool_choice",
|
||||
"seed",
|
||||
"stop",
|
||||
"response_format",
|
||||
]
|
||||
|
||||
|
@ -166,6 +171,8 @@ class MistralConfig:
|
|||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
if param == "stop":
|
||||
optional_params["stop"] = value
|
||||
if param == "tool_choice" and isinstance(value, str):
|
||||
optional_params["tool_choice"] = self._map_tool_choice(
|
||||
tool_choice=value
|
||||
|
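Taken together, these changes let callers pass an OpenAI-style `stop` parameter to Mistral models, which `MistralConfig` now maps through to the provider. A short usage sketch (model name and stop token are illustrative):

```python
import litellm

response = litellm.completion(
    model="mistral/mistral-large-latest",
    messages=[{"role": "user", "content": "Count from 1 to 10"}],
    stop=["5"],  # forwarded to Mistral via the `stop` mapping added above
)
print(response.choices[0].message.content)
```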
|
|
@ -38,6 +38,18 @@ def prompt_injection_detection_default_pt():
|
|||
|
||||
BAD_MESSAGE_ERROR_STR = "Invalid Message "
|
||||
|
||||
# used to interweave user messages, to ensure user/assistant alternating
|
||||
DEFAULT_USER_CONTINUE_MESSAGE = {
|
||||
"role": "user",
|
||||
"content": "Please continue.",
|
||||
} # similar to autogen. Only used if `litellm.modify_params=True`.
|
||||
|
||||
# used to interweave assistant messages, to ensure user/assistant alternating
|
||||
DEFAULT_ASSISTANT_CONTINUE_MESSAGE = {
|
||||
"role": "assistant",
|
||||
"content": "Please continue.",
|
||||
} # similar to autogen. Only used if `litellm.modify_params=True`.
|
||||
|
||||
|
||||
def map_system_message_pt(messages: list) -> list:
|
||||
"""
|
||||
|
@ -2254,6 +2266,7 @@ def _bedrock_converse_messages_pt(
|
|||
messages: List,
|
||||
model: str,
|
||||
llm_provider: str,
|
||||
user_continue_message: Optional[dict] = None,
|
||||
) -> List[BedrockMessageBlock]:
|
||||
"""
|
||||
Converts given messages from OpenAI format to Bedrock format
|
||||
|
@ -2264,6 +2277,21 @@ def _bedrock_converse_messages_pt(
|
|||
|
||||
contents: List[BedrockMessageBlock] = []
|
||||
msg_i = 0
|
||||
|
||||
# if initial message is assistant message
|
||||
if messages[0].get("role") is not None and messages[0]["role"] == "assistant":
|
||||
if user_continue_message is not None:
|
||||
messages.insert(0, user_continue_message)
|
||||
elif litellm.modify_params:
|
||||
messages.insert(0, DEFAULT_USER_CONTINUE_MESSAGE)
|
||||
|
||||
# if final message is assistant message
|
||||
if messages[-1].get("role") is not None and messages[-1]["role"] == "assistant":
|
||||
if user_continue_message is not None:
|
||||
messages.append(user_continue_message)
|
||||
elif litellm.modify_params:
|
||||
messages.append(DEFAULT_USER_CONTINUE_MESSAGE)
|
||||
|
||||
while msg_i < len(messages):
|
||||
user_content: List[BedrockContentBlock] = []
|
||||
init_msg_i = msg_i
|
||||
|
@ -2344,6 +2372,7 @@ def _bedrock_converse_messages_pt(
|
|||
model=model,
|
||||
llm_provider=llm_provider,
|
||||
)
|
||||
|
||||
return contents
|
||||
|
||||
|
||||
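Net effect of these prompt-factory changes: if a Bedrock Converse conversation begins or ends with an assistant turn, LiteLLM can interleave a user "continue" message, either the caller-supplied `user_continue_message` or the default above when `litellm.modify_params=True`. A hedged sketch (model name is illustrative):

```python
import litellm

litellm.modify_params = True  # allow the default "Please continue." user message to be inserted

response = litellm.completion(
    model="bedrock/meta.llama3-1-8b-instruct-v1:0",
    messages=[{"role": "assistant", "content": "Here is my draft so far..."}],  # starts with an assistant turn
)
```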
|
|
|
@ -9,7 +9,7 @@ import types
|
|||
import uuid
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
|
@ -38,12 +38,15 @@ from litellm.types.llms.vertex_ai import (
|
|||
FunctionDeclaration,
|
||||
GenerateContentResponseBody,
|
||||
GenerationConfig,
|
||||
Instance,
|
||||
InstanceVideo,
|
||||
PartType,
|
||||
RequestBody,
|
||||
SafetSettingsConfig,
|
||||
SystemInstructions,
|
||||
ToolConfig,
|
||||
Tools,
|
||||
VertexMultimodalEmbeddingRequest,
|
||||
)
|
||||
from litellm.types.utils import GenericStreamingChunk
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
|
||||
|
@ -188,9 +191,11 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
|
|||
elif value["type"] == "text": # type: ignore
|
||||
optional_params["response_mime_type"] = "text/plain"
|
||||
if "response_schema" in value: # type: ignore
|
||||
optional_params["response_mime_type"] = "application/json"
|
||||
optional_params["response_schema"] = value["response_schema"] # type: ignore
|
||||
elif value["type"] == "json_schema": # type: ignore
|
||||
if "json_schema" in value and "schema" in value["json_schema"]: # type: ignore
|
||||
optional_params["response_mime_type"] = "application/json"
|
||||
optional_params["response_schema"] = value["json_schema"]["schema"] # type: ignore
|
||||
if param == "tools" and isinstance(value, list):
|
||||
gtool_func_declarations = []
|
||||
|
@ -400,9 +405,11 @@ class VertexGeminiConfig:
|
|||
elif value["type"] == "text":
|
||||
optional_params["response_mime_type"] = "text/plain"
|
||||
if "response_schema" in value:
|
||||
optional_params["response_mime_type"] = "application/json"
|
||||
optional_params["response_schema"] = value["response_schema"]
|
||||
elif value["type"] == "json_schema": # type: ignore
|
||||
if "json_schema" in value and "schema" in value["json_schema"]: # type: ignore
|
||||
optional_params["response_mime_type"] = "application/json"
|
||||
optional_params["response_schema"] = value["json_schema"]["schema"] # type: ignore
|
||||
if param == "frequency_penalty":
|
||||
optional_params["frequency_penalty"] = value
|
||||
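These mappings translate an OpenAI-style `json_schema` response format into Gemini's `response_mime_type` plus `response_schema`. A rough sketch of the call shape (schema contents are illustrative):

```python
import litellm

response = litellm.completion(
    model="gemini/gemini-1.5-flash",
    messages=[{"role": "user", "content": "List two colors as JSON"}],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "colors",
            "schema": {
                "type": "object",
                "properties": {"colors": {"type": "array", "items": {"type": "string"}}},
            },
        },
    },
)
```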
|
@ -594,6 +601,10 @@ class VertexLLM(BaseLLM):
|
|||
self._credentials: Optional[Any] = None
|
||||
self.project_id: Optional[str] = None
|
||||
self.async_handler: Optional[AsyncHTTPHandler] = None
|
||||
self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
|
||||
"multimodalembedding",
|
||||
"multimodalembedding@001",
|
||||
]
|
||||
|
||||
def _process_response(
|
||||
self,
|
||||
|
@ -1537,6 +1548,160 @@ class VertexLLM(BaseLLM):
|
|||
|
||||
return model_response
|
||||
|
||||
def multimodal_embedding(
|
||||
self,
|
||||
model: str,
|
||||
input: Union[list, str],
|
||||
print_verbose,
|
||||
model_response: litellm.EmbeddingResponse,
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
logging_obj=None,
|
||||
encoding=None,
|
||||
vertex_project=None,
|
||||
vertex_location=None,
|
||||
vertex_credentials=None,
|
||||
aembedding=False,
|
||||
timeout=300,
|
||||
client=None,
|
||||
):
|
||||
|
||||
if client is None:
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
_httpx_timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = _httpx_timeout
|
||||
else:
|
||||
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
|
||||
sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
sync_handler = client # type: ignore
|
||||
|
||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
|
||||
|
||||
auth_header, _ = self._ensure_access_token(
|
||||
credentials=vertex_credentials, project_id=vertex_project
|
||||
)
|
||||
optional_params = optional_params or {}
|
||||
|
||||
request_data = VertexMultimodalEmbeddingRequest()
|
||||
|
||||
if "instances" in optional_params:
|
||||
request_data["instances"] = optional_params["instances"]
|
||||
elif isinstance(input, list):
|
||||
request_data["instances"] = input
|
||||
else:
|
||||
# construct instances
|
||||
vertex_request_instance = Instance(**optional_params)
|
||||
|
||||
if isinstance(input, str):
|
||||
vertex_request_instance["text"] = input
|
||||
|
||||
request_data["instances"] = [vertex_request_instance]
|
||||
|
||||
request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
|
||||
logging_obj.pre_call(
|
||||
input=[],
|
||||
api_key=None,
|
||||
additional_args={
|
||||
"complete_input_dict": optional_params,
|
||||
"request_str": request_str,
|
||||
},
|
||||
)
|
||||
|
||||
logging_obj.pre_call(
|
||||
input=[],
|
||||
api_key=None,
|
||||
additional_args={
|
||||
"complete_input_dict": optional_params,
|
||||
"request_str": request_str,
|
||||
},
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {auth_header}",
|
||||
}
|
||||
|
||||
if aembedding is True:
|
||||
return self.async_multimodal_embedding(
|
||||
model=model,
|
||||
api_base=url,
|
||||
data=request_data,
|
||||
timeout=timeout,
|
||||
headers=headers,
|
||||
client=client,
|
||||
model_response=model_response,
|
||||
)
|
||||
|
||||
response = sync_handler.post(
|
||||
url=url,
|
||||
headers=headers,
|
||||
data=json.dumps(request_data),
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||
|
||||
_json_response = response.json()
|
||||
if "predictions" not in _json_response:
|
||||
raise litellm.InternalServerError(
|
||||
message=f"embedding response does not contain 'predictions', got {_json_response}",
|
||||
llm_provider="vertex_ai",
|
||||
model=model,
|
||||
)
|
||||
_predictions = _json_response["predictions"]
|
||||
|
||||
model_response.data = _predictions
|
||||
model_response.model = model
|
||||
|
||||
return model_response
|
||||
|
||||
async def async_multimodal_embedding(
|
||||
self,
|
||||
model: str,
|
||||
api_base: str,
|
||||
data: VertexMultimodalEmbeddingRequest,
|
||||
model_response: litellm.EmbeddingResponse,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
headers={},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> litellm.EmbeddingResponse:
|
||||
if client is None:
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
client = AsyncHTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
client = client # type: ignore
|
||||
|
||||
try:
|
||||
response = await client.post(api_base, headers=headers, json=data) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise VertexAIError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise VertexAIError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
_json_response = response.json()
|
||||
if "predictions" not in _json_response:
|
||||
raise litellm.InternalServerError(
|
||||
message=f"embedding response does not contain 'predictions', got {_json_response}",
|
||||
llm_provider="vertex_ai",
|
||||
model=model,
|
||||
)
|
||||
_predictions = _json_response["predictions"]
|
||||
|
||||
model_response.data = _predictions
|
||||
model_response.model = model
|
||||
|
||||
return model_response
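A minimal sketch of exercising this multimodal embedding path from the SDK (assumes valid Vertex AI credentials are configured in the environment; the payload mirrors the test added later in this diff):

import litellm

# the raw instance dict is forwarded into VertexMultimodalEmbeddingRequest["instances"]
response = litellm.embedding(
    model="vertex_ai/multimodalembedding@001",
    input=[
        {
            "image": {
                "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
            },
            "text": "this is a unicorn",
        }
    ],
)
print(response.data)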
class ModelResponseIterator:
|
||||
def __init__(self, streaming_response, sync_stream: bool):
|
||||
|
|
|
@ -943,6 +943,7 @@ def completion(
|
|||
output_cost_per_token=output_cost_per_token,
|
||||
cooldown_time=cooldown_time,
|
||||
text_completion=kwargs.get("text_completion"),
|
||||
user_continue_message=kwargs.get("user_continue_message"),
|
||||
)
|
||||
logging.update_environment_variables(
|
||||
model=model,
|
||||
|
@ -1634,6 +1635,13 @@ def completion(
|
|||
or "https://api.cohere.ai/v1/generate"
|
||||
)
|
||||
|
||||
headers = headers or litellm.headers or {}
|
||||
if headers is None:
|
||||
headers = {}
|
||||
|
||||
if extra_headers is not None:
|
||||
headers.update(extra_headers)
|
||||
|
||||
model_response = cohere.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
@ -1644,6 +1652,7 @@ def completion(
|
|||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
headers=headers,
|
||||
api_key=cohere_key,
|
||||
logging_obj=logging, # model call logging done inside the class as we may need to modify I/O to fit aleph alpha's requirements
|
||||
)
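With the header plumbing above, per-request extra_headers now reach the Cohere call. A small usage sketch (matches test_completion_cohere further down in this diff; assumes COHERE_API_KEY is set in the environment):

import litellm

response = litellm.completion(
    model="command-r",
    messages=[{"role": "user", "content": "hey, how's it going?"}],
    extra_headers={"Helicone-Property-Locale": "ko"},  # merged into the Cohere request headers
)
print(response)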
@ -1674,6 +1683,13 @@ def completion(
|
|||
or "https://api.cohere.ai/v1/chat"
|
||||
)
|
||||
|
||||
headers = headers or litellm.headers or {}
|
||||
if headers is None:
|
||||
headers = {}
|
||||
|
||||
if extra_headers is not None:
|
||||
headers.update(extra_headers)
|
||||
|
||||
model_response = cohere_chat.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
@ -1682,6 +1698,7 @@ def completion(
|
|||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
api_key=cohere_key,
|
||||
|
@ -2288,7 +2305,7 @@ def completion(
|
|||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
litellm_params=litellm_params, # type: ignore
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
|
@ -2464,7 +2481,7 @@ def completion(
|
|||
model_response=model_response,
|
||||
encoding=encoding,
|
||||
)
|
||||
if acompletion is True or optional_params.get("stream", False) == True:
|
||||
if acompletion is True or optional_params.get("stream", False) is True:
|
||||
return generator
|
||||
|
||||
response = generator
|
||||
|
@ -3158,6 +3175,7 @@ def embedding(
|
|||
encoding_format = kwargs.get("encoding_format", None)
|
||||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||
aembedding = kwargs.get("aembedding", None)
|
||||
extra_headers = kwargs.get("extra_headers", None)
|
||||
### CUSTOM MODEL COST ###
|
||||
input_cost_per_token = kwargs.get("input_cost_per_token", None)
|
||||
output_cost_per_token = kwargs.get("output_cost_per_token", None)
|
||||
|
@ -3229,6 +3247,7 @@ def embedding(
|
|||
"model_config",
|
||||
"cooldown_time",
|
||||
"tags",
|
||||
"extra_headers",
|
||||
]
|
||||
default_params = openai_params + litellm_params
|
||||
non_default_params = {
|
||||
|
@ -3292,7 +3311,7 @@ def embedding(
|
|||
"cooldown_time": cooldown_time,
|
||||
},
|
||||
)
|
||||
if azure == True or custom_llm_provider == "azure":
|
||||
if azure is True or custom_llm_provider == "azure":
|
||||
# azure configs
|
||||
api_type = get_secret("AZURE_API_TYPE") or "azure"
|
||||
|
||||
|
@ -3398,12 +3417,18 @@ def embedding(
|
|||
or get_secret("CO_API_KEY")
|
||||
or litellm.api_key
|
||||
)
|
||||
|
||||
if extra_headers is not None and isinstance(extra_headers, dict):
|
||||
headers = extra_headers
|
||||
else:
|
||||
headers = {}
|
||||
response = cohere.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
api_key=cohere_key, # type: ignore
|
||||
headers=headers,
|
||||
logging_obj=logging,
|
||||
model_response=EmbeddingResponse(),
|
||||
aembedding=aembedding,
|
||||
|
@ -3477,6 +3502,26 @@ def embedding(
|
|||
or get_secret("VERTEX_CREDENTIALS")
|
||||
)
|
||||
|
||||
if (
|
||||
"image" in optional_params
|
||||
or "video" in optional_params
|
||||
or model in vertex_chat_completion.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
|
||||
):
|
||||
# multimodal embedding is supported on vertex httpx
|
||||
response = vertex_chat_completion.multimodal_embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
optional_params=optional_params,
|
||||
model_response=EmbeddingResponse(),
|
||||
vertex_project=vertex_ai_project,
|
||||
vertex_location=vertex_ai_location,
|
||||
vertex_credentials=vertex_credentials,
|
||||
aembedding=aembedding,
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
else:
|
||||
response = vertex_ai.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
|
|
|
@ -3,9 +3,11 @@ model_list:
|
|||
litellm_params:
|
||||
model: "*"
|
||||
|
||||
|
||||
litellm_settings:
|
||||
cache: True
|
||||
cache_params:
|
||||
type: redis
|
||||
redis_startup_nodes: [{"host": "127.0.0.1", "port": "7001"}]
|
||||
success_callback: ["s3"]
|
||||
cache: true
|
||||
s3_callback_params:
|
||||
s3_bucket_name: mytestbucketlitellm # AWS Bucket Name for S3
|
||||
s3_region_name: us-west-2 # AWS Region Name for S3
|
||||
s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # use os.environ/<variable name> to pass environment variables. This is the AWS Access Key ID for S3
|
||||
s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
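For the SDK (non-proxy) side, the equivalent cluster-aware cache setup would look roughly like the sketch below; it assumes cache_params are forwarded unchanged into litellm.Cache, so the redis_startup_nodes keyword is taken on faith from the config above rather than a verified signature:

import litellm
from litellm import Cache

# rough sketch: redis cluster via startup nodes (parameter name assumed from the proxy config)
litellm.cache = Cache(
    type="redis",
    redis_startup_nodes=[{"host": "127.0.0.1", "port": "7001"}],
)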
@ -21,6 +21,13 @@ else:
|
|||
Span = Any
|
||||
|
||||
|
||||
class LiteLLMTeamRoles(enum.Enum):
|
||||
# team admin
|
||||
TEAM_ADMIN = "admin"
|
||||
# team member
|
||||
TEAM_MEMBER = "user"
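These values line up with the "role" strings carried in a team's members_with_roles entries, e.g. (illustrative only; the user id is a placeholder):

member = {"role": LiteLLMTeamRoles.TEAM_ADMIN.value, "user_id": "user-1234"}  # role == "admin"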
|
||||
|
||||
|
||||
class LitellmUserRoles(str, enum.Enum):
|
||||
"""
|
||||
Admin Roles:
|
||||
|
@ -335,6 +342,11 @@ class LiteLLMRoutes(enum.Enum):
|
|||
+ sso_only_routes
|
||||
)
|
||||
|
||||
self_managed_routes: List = [
|
||||
"/team/member_add",
|
||||
"/team/member_delete",
|
||||
] # routes that manage their own allowed/disallowed logic
|
||||
|
||||
|
||||
# class LiteLLMAllowedRoutes(LiteLLMBase):
|
||||
# """
|
||||
|
@ -1308,6 +1320,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
|
|||
soft_budget: Optional[float] = None
|
||||
team_model_aliases: Optional[Dict] = None
|
||||
team_member_spend: Optional[float] = None
|
||||
team_member: Optional[Member] = None
|
||||
team_metadata: Optional[Dict] = None
|
||||
|
||||
# End User Params
|
||||
|
|
|
@ -975,8 +975,6 @@ async def user_api_key_auth(
|
|||
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
|
||||
if is_llm_api_route(route=route):
|
||||
pass
|
||||
elif is_llm_api_route(route=request["route"].name):
|
||||
pass
|
||||
elif (
|
||||
route in LiteLLMRoutes.info_routes.value
|
||||
): # check if user allowed to call an info route
|
||||
|
@ -1046,11 +1044,16 @@ async def user_api_key_auth(
|
|||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
|
||||
)
|
||||
|
||||
elif (
|
||||
_user_role == LitellmUserRoles.INTERNAL_USER.value
|
||||
and route in LiteLLMRoutes.internal_user_routes.value
|
||||
):
|
||||
pass
|
||||
elif (
|
||||
route in LiteLLMRoutes.self_managed_routes.value
|
||||
): # routes that manage their own allowed/disallowed logic
|
||||
pass
|
||||
else:
|
||||
user_role = "unknown"
|
||||
user_id = "unknown"
|
||||
|
|
|
@ -120,6 +120,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
max_parallel_requests = user_api_key_dict.max_parallel_requests
|
||||
if max_parallel_requests is None:
|
||||
max_parallel_requests = sys.maxsize
|
||||
if data is None:
|
||||
data = {}
|
||||
global_max_parallel_requests = data.get("metadata", {}).get(
|
||||
"global_max_parallel_requests", None
|
||||
)
|
||||
|
|
|
@ -119,6 +119,7 @@ async def new_user(
|
|||
http_request=Request(
|
||||
scope={"type": "http", "path": "/user/new"},
|
||||
),
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
|
||||
if data.send_invite_email is True:
|
||||
|
@ -732,7 +733,7 @@ async def delete_user(
|
|||
delete user and associated user keys
|
||||
|
||||
```
|
||||
curl --location 'http://0.0.0.0:8000/team/delete' \
|
||||
curl --location 'http://0.0.0.0:8000/user/delete' \
|
||||
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
|
||||
|
|
|
@ -849,7 +849,7 @@ async def generate_key_helper_fn(
|
|||
}
|
||||
|
||||
if (
|
||||
litellm.get_secret("DISABLE_KEY_NAME", False) == True
|
||||
litellm.get_secret("DISABLE_KEY_NAME", False) is True
|
||||
): # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much)
|
||||
pass
|
||||
else:
|
||||
|
|
|
@ -30,7 +30,7 @@ from litellm.proxy._types import (
|
|||
UpdateTeamRequest,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin, user_api_key_auth
|
||||
from litellm.proxy.management_helpers.utils import (
|
||||
add_new_member,
|
||||
management_endpoint_wrapper,
|
||||
|
@ -39,6 +39,16 @@ from litellm.proxy.management_helpers.utils import (
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
def _is_user_team_admin(
|
||||
user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
|
||||
) -> bool:
|
||||
for member in team_obj.members_with_roles:
|
||||
if member.user_id is not None and member.user_id == user_api_key_dict.user_id:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
#### TEAM MANAGEMENT ####
|
||||
@router.post(
|
||||
"/team/new",
|
||||
|
@ -417,6 +427,7 @@ async def team_member_add(
|
|||
|
||||
If user doesn't exist, new user row will also be added to User Table
|
||||
|
||||
Only proxy_admin or a team admin is allowed to access this endpoint.
|
||||
```
|
||||
|
||||
curl -X POST 'http://0.0.0.0:4000/team/member_add' \
|
||||
|
@ -465,6 +476,25 @@ async def team_member_add(
|
|||
|
||||
complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump())
|
||||
|
||||
## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN
|
||||
|
||||
if (
|
||||
hasattr(user_api_key_dict, "user_role")
|
||||
and user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
|
||||
and not _is_user_team_admin(
|
||||
user_api_key_dict=user_api_key_dict, team_obj=complete_team_data
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
|
||||
"/team/member_add",
|
||||
complete_team_data.team_id,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
if isinstance(data.member, Member):
|
||||
# add to team db
|
||||
new_member = data.member
|
||||
|
@ -569,6 +599,23 @@ async def team_member_delete(
|
|||
)
|
||||
existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump())
|
||||
|
||||
## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN
|
||||
|
||||
if (
|
||||
user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
|
||||
and not _is_user_team_admin(
|
||||
user_api_key_dict=user_api_key_dict, team_obj=existing_team_row
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
|
||||
"/team/member_delete", existing_team_row.team_id
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
## DELETE MEMBER FROM TEAM
|
||||
new_team_members: List[Member] = []
|
||||
for m in existing_team_row.members_with_roles:
|
||||
|
|
|
@ -266,7 +266,7 @@ def management_endpoint_wrapper(func):
|
|||
)
|
||||
|
||||
_http_request: Request = kwargs.get("http_request")
|
||||
parent_otel_span = user_api_key_dict.parent_otel_span
|
||||
parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None)
|
||||
if parent_otel_span is not None:
|
||||
from litellm.proxy.proxy_server import open_telemetry_logger
|
||||
|
||||
|
@ -310,7 +310,7 @@ def management_endpoint_wrapper(func):
|
|||
user_api_key_dict: UserAPIKeyAuth = (
|
||||
kwargs.get("user_api_key_dict") or UserAPIKeyAuth()
|
||||
)
|
||||
parent_otel_span = user_api_key_dict.parent_otel_span
|
||||
parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None)
|
||||
if parent_otel_span is not None:
|
||||
from litellm.proxy.proxy_server import open_telemetry_logger
|
||||
|
||||
|
|
|
@ -301,16 +301,19 @@ async def pass_through_request(
|
|||
request=request, headers=headers, forward_headers=forward_headers
|
||||
)
|
||||
|
||||
_parsed_body = None
|
||||
if custom_body:
|
||||
_parsed_body = custom_body
|
||||
else:
|
||||
request_body = await request.body()
|
||||
if request_body == b"" or request_body is None:
|
||||
_parsed_body = None
|
||||
else:
|
||||
body_str = request_body.decode()
|
||||
try:
|
||||
_parsed_body = ast.literal_eval(body_str)
|
||||
except Exception:
|
||||
_parsed_body = json.loads(body_str)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
|
||||
url, headers, _parsed_body
|
||||
|
@ -320,7 +323,7 @@ async def pass_through_request(
|
|||
### CALL HOOKS ### - modify incoming data / reject request before calling the model
|
||||
_parsed_body = await proxy_logging_obj.pre_call_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
data=_parsed_body,
|
||||
data=_parsed_body or {},
|
||||
call_type="pass_through_endpoint",
|
||||
)
|
||||
|
||||
|
@ -360,11 +363,20 @@ async def pass_through_request(
|
|||
|
||||
# combine url with query params for logging
|
||||
|
||||
requested_query_params = query_params or request.query_params.__dict__
|
||||
requested_query_params: Optional[dict] = (
|
||||
query_params or request.query_params.__dict__
|
||||
)
|
||||
if requested_query_params == request.query_params.__dict__:
|
||||
requested_query_params = None
|
||||
|
||||
requested_query_params_str = None
|
||||
if requested_query_params:
|
||||
requested_query_params_str = "&".join(
|
||||
f"{k}={v}" for k, v in requested_query_params.items()
|
||||
)
|
||||
|
||||
logging_url = str(url)
|
||||
if requested_query_params_str:
|
||||
if "?" in str(url):
|
||||
logging_url = str(url) + "&" + requested_query_params_str
|
||||
else:
|
||||
|
@ -409,6 +421,14 @@ async def pass_through_request(
|
|||
status_code=response.status_code,
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug("request method: {}".format(request.method))
|
||||
verbose_proxy_logger.debug("request url: {}".format(url))
|
||||
verbose_proxy_logger.debug("request headers: {}".format(headers))
|
||||
verbose_proxy_logger.debug(
|
||||
"requested_query_params={}".format(requested_query_params)
|
||||
)
|
||||
verbose_proxy_logger.debug("request body: {}".format(_parsed_body))
|
||||
|
||||
response = await async_client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
|
|
|
@ -1,20 +1,18 @@
|
|||
model_list:
|
||||
- model_name: fake-openai-endpoint
|
||||
- model_name: gpt-4
|
||||
litellm_params:
|
||||
model: openai/fake
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
- model_name: openai-embedding
|
||||
litellm_params:
|
||||
model: openai/text-embedding-3-small
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
litellm_settings:
|
||||
set_verbose: True
|
||||
cache: True # set cache responses to True, litellm defaults to using a redis cache
|
||||
cache_params:
|
||||
type: qdrant-semantic
|
||||
qdrant_semantic_cache_embedding_model: openai-embedding
|
||||
qdrant_collection_name: test_collection
|
||||
qdrant_quantization_config: binary
|
||||
similarity_threshold: 0.8 # similarity threshold for semantic cache
|
||||
guardrails:
|
||||
- guardrail_name: "lakera-pre-guard"
|
||||
litellm_params:
|
||||
guardrail: lakera # supported values: "aporia", "bedrock", "lakera"
|
||||
mode: "during_call"
|
||||
api_key: os.environ/LAKERA_API_KEY
|
||||
api_base: os.environ/LAKERA_API_BASE
|
||||
category_thresholds:
|
||||
prompt_injection: 0.1
|
||||
jailbreak: 0.1
|
||||
|
|
@ -1498,6 +1498,11 @@ class ProxyConfig:
|
|||
litellm.get_secret(secret_name=key, default_value=value)
|
||||
)
|
||||
|
||||
# check if litellm_license in general_settings
|
||||
if "LITELLM_LICENSE" in environment_variables:
|
||||
_license_check.license_str = os.getenv("LITELLM_LICENSE", None)
|
||||
premium_user = _license_check.is_premium()
|
||||
|
||||
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
||||
litellm_settings = config.get("litellm_settings", None)
|
||||
if litellm_settings is None:
|
||||
|
@ -1878,6 +1883,11 @@ class ProxyConfig:
|
|||
+ CommonProxyErrors.not_premium_user.value
|
||||
)
|
||||
|
||||
# check if litellm_license in general_settings
|
||||
if "litellm_license" in general_settings:
|
||||
_license_check.license_str = general_settings["litellm_license"]
|
||||
premium_user = _license_check.is_premium()
|
||||
|
||||
router_params: dict = {
|
||||
"cache_responses": litellm.cache
|
||||
!= None, # cache if user passed in cache values
|
||||
|
@ -2784,11 +2794,14 @@ async def startup_event():
|
|||
await custom_db_client.connect()
|
||||
|
||||
if prisma_client is not None and master_key is not None:
|
||||
# add master key to db
|
||||
if os.getenv("PROXY_ADMIN_ID", None) is not None:
|
||||
litellm_proxy_admin_name = os.getenv(
|
||||
"PROXY_ADMIN_ID", litellm_proxy_admin_name
|
||||
)
|
||||
if general_settings.get("disable_adding_master_key_hash_to_db") is True:
|
||||
verbose_proxy_logger.info("Skipping writing master key hash to db")
|
||||
else:
|
||||
# add master key to db
|
||||
asyncio.create_task(
|
||||
generate_key_helper_fn(
|
||||
request_type="user",
|
||||
|
@ -3011,6 +3024,29 @@ async def chat_completion(
|
|||
model: Optional[str] = None,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
|
||||
Follows the exact same API spec as `OpenAI's Chat API https://platform.openai.com/docs/api-reference/chat`
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:4000/v1/chat/completions \
|
||||
|
||||
-H "Content-Type: application/json" \
|
||||
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
|
||||
-d '{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello!"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
"""
|
||||
global general_settings, user_debug, proxy_logging_obj, llm_model_list
|
||||
|
||||
data = {}
|
||||
|
@ -3268,6 +3304,24 @@ async def completion(
|
|||
model: Optional[str] = None,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Follows the exact same API spec as `OpenAI's Completions API https://platform.openai.com/docs/api-reference/completions`
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:4000/v1/completions \
|
||||
|
||||
-H "Content-Type: application/json" \
|
||||
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
|
||||
-d '{
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": "Once upon a time",
|
||||
"max_tokens": 50,
|
||||
"temperature": 0.7
|
||||
}'
|
||||
```
|
||||
"""
|
||||
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
|
||||
data = {}
|
||||
try:
|
||||
|
@ -3474,6 +3528,23 @@ async def embeddings(
|
|||
model: Optional[str] = None,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Follows the exact same API spec as `OpenAI's Embeddings API https://platform.openai.com/docs/api-reference/embeddings`
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:4000/v1/embeddings \
|
||||
|
||||
-H "Content-Type: application/json" \
|
||||
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
|
||||
-d '{
|
||||
"model": "text-embedding-ada-002",
|
||||
"input": "The quick brown fox jumps over the lazy dog"
|
||||
}'
|
||||
```
|
||||
|
||||
"""
|
||||
global proxy_logging_obj
|
||||
data: Any = {}
|
||||
try:
|
||||
|
@ -3481,6 +3552,11 @@ async def embeddings(
|
|||
body = await request.body()
|
||||
data = orjson.loads(body)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Request received by LiteLLM:\n%s",
|
||||
json.dumps(data, indent=4),
|
||||
)
|
||||
|
||||
# Include original request and headers in the data
|
||||
data = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import json
|
||||
import os
|
||||
import secrets
|
||||
import traceback
|
||||
from typing import Optional
|
||||
|
||||
|
@ -8,12 +10,30 @@ from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
|
|||
from litellm.proxy.utils import hash_token
|
||||
|
||||
|
||||
def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
|
||||
if _master_key is None:
|
||||
return False
|
||||
|
||||
## string comparison
|
||||
is_master_key = secrets.compare_digest(api_key, _master_key)
|
||||
if is_master_key:
|
||||
return True
|
||||
|
||||
## hash comparison
|
||||
is_master_key = secrets.compare_digest(api_key, hash_token(_master_key))
|
||||
if is_master_key:
|
||||
return True
|
||||
|
||||
return False
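A quick illustration of the intended behavior (not part of the diff; hash_token is the same helper imported above):

master_key = "sk-1234"

# matches on the raw key ...
assert _is_master_key(api_key=master_key, _master_key=master_key) is True
# ... and on its hashed form, which is what spend logs usually carry
assert _is_master_key(api_key=hash_token(master_key), _master_key=master_key) is True
# anything else, or no configured master key, is rejected
assert _is_master_key(api_key="sk-other", _master_key=master_key) is False
assert _is_master_key(api_key=master_key, _master_key=None) is False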
def get_logging_payload(
|
||||
kwargs, response_obj, start_time, end_time, end_user_id: Optional[str]
|
||||
) -> SpendLogsPayload:
|
||||
from pydantic import Json
|
||||
|
||||
from litellm.proxy._types import LiteLLM_SpendLogs
|
||||
from litellm.proxy.proxy_server import general_settings, master_key
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"SpendTable: get_logging_payload - kwargs: {kwargs}\n\n"
|
||||
|
@ -36,9 +56,15 @@ def get_logging_payload(
|
|||
usage = dict(usage)
|
||||
id = response_obj.get("id", kwargs.get("litellm_call_id"))
|
||||
api_key = metadata.get("user_api_key", "")
|
||||
if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"):
|
||||
if api_key is not None and isinstance(api_key, str):
|
||||
if api_key.startswith("sk-"):
|
||||
# hash the api_key
|
||||
api_key = hash_token(api_key)
|
||||
if (
|
||||
_is_master_key(api_key=api_key, _master_key=master_key)
|
||||
and general_settings.get("disable_adding_master_key_hash_to_db") is True
|
||||
):
|
||||
api_key = "litellm_proxy_master_key" # use a known alias, if the user disabled storing master key in db
|
||||
|
||||
_model_id = metadata.get("model_info", {}).get("id", "")
|
||||
_model_group = metadata.get("model_group", "")
|
||||
|
|
litellm/proxy/tests/test_vtx_embedding.py (new file, 21 lines)
@ -0,0 +1,21 @@
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

# request sent to model set on litellm proxy, `litellm --model`
response = client.embeddings.create(
    model="multimodalembedding@001",
    input=[],
    extra_body={
        "instances": [
            {
                "image": {
                    "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
                },
                "text": "this is a unicorn",
            },
        ],
    },
)

print(response)
litellm/proxy/tests/test_vtx_sdk_embedding.py (new file, 58 lines)
@ -0,0 +1,58 @@
import vertexai
from google.auth.credentials import Credentials
from vertexai.vision_models import (
    Image,
    MultiModalEmbeddingModel,
    Video,
    VideoSegmentConfig,
)

LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"

import datetime


class CredentialsWrapper(Credentials):
    def __init__(self, token=None):
        super().__init__()
        self.token = token
        self.expiry = None  # or set to a future date if needed

    def refresh(self, request):
        pass

    def apply(self, headers, token=None):
        headers["Authorization"] = f"Bearer {self.token}"

    @property
    def expired(self):
        return False  # Always consider the token as non-expired

    @property
    def valid(self):
        return True  # Always consider the credentials as valid


credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)

vertexai.init(
    project="adroit-crow-413218",
    location="us-central1",
    api_endpoint=LITELLM_PROXY_BASE,
    credentials=credentials,
    api_transport="rest",
)

model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding")
image = Image.load_from_file(
    "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
)

embeddings = model.get_embeddings(
    image=image,
    contextual_text="Colosseum",
    dimension=1408,
)
print(f"Image Embedding: {embeddings.image_embedding}")
print(f"Text Embedding: {embeddings.text_embedding}")
@ -44,6 +44,7 @@ from litellm.proxy._types import (
|
|||
DynamoDBArgs,
|
||||
LiteLLM_VerificationTokenView,
|
||||
LitellmUserRoles,
|
||||
Member,
|
||||
ResetTeamBudgetRequest,
|
||||
SpendLogsMetadata,
|
||||
SpendLogsPayload,
|
||||
|
@ -1395,6 +1396,7 @@ class PrismaClient:
|
|||
t.blocked AS team_blocked,
|
||||
t.team_alias AS team_alias,
|
||||
t.metadata AS team_metadata,
|
||||
t.members_with_roles AS team_members_with_roles,
|
||||
tm.spend AS team_member_spend,
|
||||
m.aliases as team_model_aliases
|
||||
FROM "LiteLLM_VerificationToken" AS v
|
||||
|
@ -1412,6 +1414,33 @@ class PrismaClient:
|
|||
response["team_models"] = []
|
||||
if response["team_blocked"] is None:
|
||||
response["team_blocked"] = False
|
||||
|
||||
team_member: Optional[Member] = None
|
||||
if (
|
||||
response["team_members_with_roles"] is not None
|
||||
and response["user_id"] is not None
|
||||
):
|
||||
## find the team member corresponding to user id
|
||||
"""
|
||||
[
|
||||
{
|
||||
"role": "admin",
|
||||
"user_id": "default_user_id",
|
||||
"user_email": null
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"user_id": null,
|
||||
"user_email": "test@email.com"
|
||||
}
|
||||
]
|
||||
"""
|
||||
for tm in response["team_members_with_roles"]:
|
||||
if tm.get("user_id") is not None and response[
|
||||
"user_id"
|
||||
] == tm.get("user_id"):
|
||||
team_member = Member(**tm)
|
||||
response["team_member"] = team_member
|
||||
response = LiteLLM_VerificationTokenView(
|
||||
**response, last_refreshed_at=time.time()
|
||||
)
|
||||
|
|
|
@ -25,6 +25,9 @@ from litellm.batches.main import FileObject
|
|||
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
|
||||
create_pass_through_route,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
default_vertex_config = None
|
||||
|
@ -70,10 +73,17 @@ def exception_handler(e: Exception):
|
|||
)
|
||||
|
||||
|
||||
async def execute_post_vertex_ai_request(
|
||||
@router.api_route(
|
||||
"/vertex-ai/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]
|
||||
)
|
||||
async def vertex_proxy_route(
|
||||
endpoint: str,
|
||||
request: Request,
|
||||
route: str,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
encoded_endpoint = httpx.URL(endpoint).path
|
||||
|
||||
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
|
||||
|
||||
if default_vertex_config is None:
|
||||
|
@ -83,250 +93,52 @@ async def execute_post_vertex_ai_request(
|
|||
vertex_project = default_vertex_config.get("vertex_project", None)
|
||||
vertex_location = default_vertex_config.get("vertex_location", None)
|
||||
vertex_credentials = default_vertex_config.get("vertex_credentials", None)
|
||||
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
|
||||
|
||||
request_data_json = {}
|
||||
body = await request.body()
|
||||
body_str = body.decode()
|
||||
if len(body_str) > 0:
|
||||
try:
|
||||
request_data_json = ast.literal_eval(body_str)
|
||||
except:
|
||||
request_data_json = json.loads(body_str)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Request received by LiteLLM:\n{}".format(
|
||||
json.dumps(request_data_json, indent=4)
|
||||
),
|
||||
)
|
||||
|
||||
response = (
|
||||
await vertex_fine_tuning_apis_instance.pass_through_vertex_ai_POST_request(
|
||||
request_data=request_data_json,
|
||||
auth_header, _ = vertex_fine_tuning_apis_instance._get_token_and_url(
|
||||
model="",
|
||||
gemini_api_key=None,
|
||||
vertex_credentials=vertex_credentials,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=vertex_location,
|
||||
vertex_credentials=vertex_credentials,
|
||||
request_route=route,
|
||||
)
|
||||
stream=False,
|
||||
custom_llm_provider="vertex_ai_beta",
|
||||
api_base="",
|
||||
)
|
||||
|
||||
return response
|
||||
headers = {
|
||||
"Authorization": f"Bearer {auth_header}",
|
||||
}
|
||||
|
||||
request_route = encoded_endpoint
|
||||
verbose_proxy_logger.debug("request_route %s", request_route)
|
||||
|
||||
@router.post(
|
||||
"/vertex-ai/publishers/google/models/{model_id:path}:generateContent",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["Vertex AI endpoints"],
|
||||
)
|
||||
async def vertex_generate_content(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
model_id: str,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
this is a pass through endpoint for the Vertex AI API. /generateContent endpoint
|
||||
# Ensure endpoint starts with '/' for proper URL construction
|
||||
if not encoded_endpoint.startswith("/"):
|
||||
encoded_endpoint = "/" + encoded_endpoint
|
||||
|
||||
Example Curl:
|
||||
```
|
||||
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
|
||||
```
|
||||
# Construct the full target URL using httpx
|
||||
base_url = httpx.URL(base_target_url)
|
||||
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||
|
||||
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#rest
|
||||
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
|
||||
"""
|
||||
try:
|
||||
response = await execute_post_vertex_ai_request(
|
||||
request=request,
|
||||
route=f"/publishers/google/models/{model_id}:generateContent",
|
||||
verbose_proxy_logger.debug("updated url %s", updated_url)
|
||||
|
||||
## check for streaming
|
||||
is_streaming_request = False
|
||||
if "stream" in str(updated_url):
|
||||
is_streaming_request = True
|
||||
|
||||
## CREATE PASS-THROUGH
|
||||
endpoint_func = create_pass_through_route(
|
||||
endpoint=endpoint,
|
||||
target=str(updated_url),
|
||||
custom_headers=headers,
|
||||
) # dynamically construct pass-through endpoint based on incoming path
|
||||
received_value = await endpoint_func(
|
||||
request,
|
||||
fastapi_response,
|
||||
user_api_key_dict,
|
||||
stream=is_streaming_request,
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise exception_handler(e) from e
|
||||
|
||||
|
||||
@router.post(
|
||||
"/vertex-ai/publishers/google/models/{model_id:path}:predict",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["Vertex AI endpoints"],
|
||||
)
|
||||
async def vertex_predict_endpoint(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
model_id: str,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
this is a pass through endpoint for the Vertex AI API. /predict endpoint
|
||||
Use this for:
|
||||
- Embeddings API - Text Embedding, Multi Modal Embedding
|
||||
- Imagen API
|
||||
- Code Completion API
|
||||
|
||||
Example Curl:
|
||||
```
|
||||
curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:predict \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{"instances":[{"content": "gm"}]}'
|
||||
```
|
||||
|
||||
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#generative-ai-get-text-embedding-drest
|
||||
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
|
||||
"""
|
||||
try:
|
||||
response = await execute_post_vertex_ai_request(
|
||||
request=request,
|
||||
route=f"/publishers/google/models/{model_id}:predict",
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise exception_handler(e) from e
|
||||
|
||||
|
||||
@router.post(
|
||||
"/vertex-ai/publishers/google/models/{model_id:path}:countTokens",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["Vertex AI endpoints"],
|
||||
)
|
||||
async def vertex_countTokens_endpoint(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
model_id: str,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
this is a pass through endpoint for the Vertex AI API. /countTokens endpoint
|
||||
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/count-tokens#curl
|
||||
|
||||
|
||||
Example Curl:
|
||||
```
|
||||
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
|
||||
```
|
||||
|
||||
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
|
||||
"""
|
||||
try:
|
||||
response = await execute_post_vertex_ai_request(
|
||||
request=request,
|
||||
route=f"/publishers/google/models/{model_id}:countTokens",
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise exception_handler(e) from e
|
||||
|
||||
|
||||
@router.post(
|
||||
"/vertex-ai/batchPredictionJobs",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["Vertex AI endpoints"],
|
||||
)
|
||||
async def vertex_create_batch_prediction_job(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
this is a pass through endpoint for the Vertex AI API. /batchPredictionJobs endpoint
|
||||
|
||||
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/batch-prediction-api#syntax
|
||||
|
||||
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
|
||||
"""
|
||||
try:
|
||||
response = await execute_post_vertex_ai_request(
|
||||
request=request,
|
||||
route="/batchPredictionJobs",
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise exception_handler(e) from e
|
||||
|
||||
|
||||
@router.post(
|
||||
"/vertex-ai/tuningJobs",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["Vertex AI endpoints"],
|
||||
)
|
||||
async def vertex_create_fine_tuning_job(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
this is a pass through endpoint for the Vertex AI API. /tuningJobs endpoint
|
||||
|
||||
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning
|
||||
|
||||
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
|
||||
"""
|
||||
try:
|
||||
response = await execute_post_vertex_ai_request(
|
||||
request=request,
|
||||
route="/tuningJobs",
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise exception_handler(e) from e
|
||||
|
||||
|
||||
@router.post(
|
||||
"/vertex-ai/tuningJobs/{job_id:path}:cancel",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["Vertex AI endpoints"],
|
||||
)
|
||||
async def vertex_cancel_fine_tuning_job(
|
||||
request: Request,
|
||||
job_id: str,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
this is a pass through endpoint for the Vertex AI API. tuningJobs/{job_id:path}:cancel
|
||||
|
||||
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#cancel_a_tuning_job
|
||||
|
||||
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
|
||||
"""
|
||||
try:
|
||||
|
||||
response = await execute_post_vertex_ai_request(
|
||||
request=request,
|
||||
route=f"/tuningJobs/{job_id}:cancel",
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise exception_handler(e) from e
|
||||
|
||||
|
||||
@router.post(
|
||||
"/vertex-ai/cachedContents",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["Vertex AI endpoints"],
|
||||
)
|
||||
async def vertex_create_add_cached_content(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
this is a pass through endpoint for the Vertex AI API. /cachedContents endpoint
|
||||
|
||||
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-create#create-context-cache-sample-drest
|
||||
|
||||
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
|
||||
"""
|
||||
try:
|
||||
response = await execute_post_vertex_ai_request(
|
||||
request=request,
|
||||
route="/cachedContents",
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise exception_handler(e) from e
|
||||
return received_value
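Because the per-endpoint handlers above are collapsed into this single catch-all route, any Vertex AI path can now be proxied. A rough sketch of calling it directly with httpx (host, key and model id are the same placeholders used in the old docstrings):

import httpx

resp = httpx.post(
    "http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent",
    headers={
        "Authorization": "Bearer sk-1234",
        "Content-Type": "application/json",
    },
    json={"contents": [{"role": "user", "parts": [{"text": "hi"}]}]},
)
print(resp.json())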
@ -15,7 +15,7 @@ import asyncio
|
|||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -501,6 +501,8 @@ async def test_async_vertexai_streaming_response():
|
|||
assert len(complete_response) > 0
|
||||
except litellm.RateLimitError as e:
|
||||
pass
|
||||
except litellm.APIConnectionError:
|
||||
pass
|
||||
except litellm.Timeout as e:
|
||||
pass
|
||||
except litellm.InternalServerError as e:
|
||||
|
@ -955,6 +957,8 @@ async def test_partner_models_httpx(model, sync_mode):
|
|||
assert isinstance(response._hidden_params["response_cost"], float)
|
||||
except litellm.RateLimitError as e:
|
||||
pass
|
||||
except litellm.InternalServerError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
if "429 Quota exceeded" in str(e):
|
||||
pass
|
||||
|
@ -1004,7 +1008,9 @@ async def test_partner_models_httpx_streaming(model, sync_mode):
|
|||
idx += 1
|
||||
|
||||
print(f"response: {response}")
|
||||
except litellm.RateLimitError:
|
||||
except litellm.RateLimitError as e:
|
||||
pass
|
||||
except litellm.InternalServerError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
if "429 Quota exceeded" in str(e):
|
||||
|
@ -1558,6 +1564,16 @@ async def test_gemini_pro_json_schema_args_sent_httpx_openai_schema(
|
|||
"response_schema"
|
||||
in mock_call.call_args.kwargs["json"]["generationConfig"]
|
||||
)
|
||||
assert (
|
||||
"response_mime_type"
|
||||
in mock_call.call_args.kwargs["json"]["generationConfig"]
|
||||
)
|
||||
assert (
|
||||
mock_call.call_args.kwargs["json"]["generationConfig"][
|
||||
"response_mime_type"
|
||||
]
|
||||
== "application/json"
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
"response_schema"
|
||||
|
@ -1826,6 +1842,71 @@ def test_vertexai_embedding():
|
|||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vertexai_multimodal_embedding():
|
||||
load_vertex_ai_credentials()
|
||||
mock_response = AsyncMock()
|
||||
|
||||
def return_val():
|
||||
return {
|
||||
"predictions": [
|
||||
{
|
||||
"imageEmbedding": [0.1, 0.2, 0.3], # Simplified example
|
||||
"textEmbedding": [0.4, 0.5, 0.6], # Simplified example
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
mock_response.json = return_val
|
||||
mock_response.status_code = 200
|
||||
|
||||
expected_payload = {
|
||||
"instances": [
|
||||
{
|
||||
"image": {
|
||||
"gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
|
||||
},
|
||||
"text": "this is a unicorn",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
|
||||
return_value=mock_response,
|
||||
) as mock_post:
|
||||
# Act: Call the litellm.aembedding function
|
||||
response = await litellm.aembedding(
|
||||
model="vertex_ai/multimodalembedding@001",
|
||||
input=[
|
||||
{
|
||||
"image": {
|
||||
"gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
|
||||
},
|
||||
"text": "this is a unicorn",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_post.assert_called_once()
|
||||
_, kwargs = mock_post.call_args
|
||||
args_to_vertexai = kwargs["json"]
|
||||
|
||||
print("args to vertex ai call:", args_to_vertexai)
|
||||
|
||||
assert args_to_vertexai == expected_payload
|
||||
assert response.model == "multimodalembedding@001"
|
||||
assert len(response.data) == 1
|
||||
response_data = response.data[0]
|
||||
assert "imageEmbedding" in response_data
|
||||
assert "textEmbedding" in response_data
|
||||
|
||||
# Optional: Print for debugging
|
||||
print("Arguments passed to Vertex AI:", args_to_vertexai)
|
||||
print("Response:", response)
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="new test - works locally running into vertex version issues on ci/cd"
|
||||
)
|
||||
|
|
|
@ -738,8 +738,9 @@ def test_bedrock_system_prompt(system, model):
|
|||
"temperature": 0.3,
|
||||
"messages": [
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": "hey, how's it going?"},
|
||||
{"role": "assistant", "content": "hey, how's it going?"},
|
||||
],
|
||||
"user_continue_message": {"role": "user", "content": "Be a good bot!"},
|
||||
}
|
||||
response: ModelResponse = completion(
|
||||
model="bedrock/{}".format(model),
|
||||
|
|
|
@ -3653,6 +3653,7 @@ def test_completion_cohere():
|
|||
response = completion(
|
||||
model="command-r",
|
||||
messages=messages,
|
||||
extra_headers={"Helicone-Property-Locale": "ko"},
|
||||
)
|
||||
print(response)
|
||||
except Exception as e:
|
||||
|
|
|
@ -1252,3 +1252,48 @@ def test_standard_logging_payload(model, turn_off_message_logging):
|
|||
]
|
||||
if turn_off_message_logging:
|
||||
assert "redacted-by-litellm" == slobject["messages"][0]["content"]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Works locally. Flaky on ci/cd")
|
||||
def test_aaastandard_logging_payload_cache_hit():
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
# sync completion
|
||||
|
||||
litellm.cache = Cache()
|
||||
|
||||
_ = litellm.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||
caching=True,
|
||||
)
|
||||
|
||||
customHandler = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler]
|
||||
litellm.success_callback = []
|
||||
|
||||
with patch.object(
|
||||
customHandler, "log_success_event", new=MagicMock()
|
||||
) as mock_client:
|
||||
_ = litellm.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||
caching=True,
|
||||
)
|
||||
|
||||
time.sleep(2)
|
||||
mock_client.assert_called_once()
|
||||
|
||||
assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
|
||||
assert (
|
||||
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
|
||||
is not None
|
||||
)
|
||||
|
||||
standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
|
||||
"kwargs"
|
||||
]["standard_logging_object"]
|
||||
|
||||
assert standard_logging_object["cache_hit"] is True
|
||||
assert standard_logging_object["response_cost"] == 0
|
||||
assert standard_logging_object["saved_cache_cost"] > 0
|
||||
|
|
|
@ -54,6 +54,7 @@ def get_current_weather(location, unit="fahrenheit"):
|
|||
)
|
||||
def test_parallel_function_call(model):
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
# Step 1: send the conversation and available functions to the model
|
||||
messages = [
|
||||
{
|
||||
|
@ -141,6 +142,8 @@ def test_parallel_function_call(model):
|
|||
drop_params=True,
|
||||
) # get a new response from the model where it can see the function response
|
||||
print("second response\n", second_response)
|
||||
except litellm.RateLimitError:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
@ -322,6 +325,7 @@ def test_groq_parallel_function_call():
|
|||
location=function_args.get("location"),
|
||||
unit=function_args.get("unit"),
|
||||
)
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"tool_call_id": tool_call.id,
|
||||
|
@ -337,27 +341,3 @@ def test_groq_parallel_function_call():
|
|||
print("second response\n", second_response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["gemini/gemini-1.5-pro"])
|
||||
def test_simple_function_call_function_param(model):
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
|
||||
response = completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "plot",
|
||||
"description": "Generate plots",
|
||||
},
|
||||
}
|
||||
],
|
||||
tool_choice="auto",
|
||||
)
|
||||
print(f"response: {response}")
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
|
|
@ -116,6 +116,8 @@ async def test_async_image_generation_openai():
|
|||
)
|
||||
print(f"response: {response}")
|
||||
assert len(response.data) > 0
|
||||
except litellm.APIError:
|
||||
pass
|
||||
except litellm.RateLimitError as e:
|
||||
pass
|
||||
except litellm.ContentPolicyViolationError:
|
||||
|
|
|
@ -2328,6 +2328,11 @@ async def test_master_key_hashing(prisma_client):
|
|||
from litellm.proxy.proxy_server import user_api_key_cache
|
||||
|
||||
_team_id = "ishaans-special-team_{}".format(uuid.uuid4())
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
)
|
||||
await new_team(
|
||||
NewTeamRequest(team_id=_team_id),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
|
@ -2343,7 +2348,8 @@ async def test_master_key_hashing(prisma_client):
|
|||
models=["azure-gpt-3.5"],
|
||||
team_id=_team_id,
|
||||
tpm_limit=20,
|
||||
)
|
||||
),
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
print(_response)
|
||||
assert _response.models == ["azure-gpt-3.5"]
|
||||
|
|
|
@ -19,7 +19,11 @@ from litellm.types.completion import (
|
|||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
from litellm.utils import get_optional_params, get_optional_params_embeddings
|
||||
from litellm.utils import (
|
||||
get_optional_params,
|
||||
get_optional_params_embeddings,
|
||||
get_optional_params_image_gen,
|
||||
)
|
||||
|
||||
## get_optional_params_embeddings
|
||||
### Models: OpenAI, Azure, Bedrock
|
||||
|
@ -430,7 +434,6 @@ def test_get_optional_params_image_gen():
|
|||
print(response)
|
||||
|
||||
assert "aws_region_name" not in response
|
||||
|
||||
response = litellm.utils.get_optional_params_image_gen(
|
||||
aws_region_name="us-east-1", custom_llm_provider="bedrock"
|
||||
)
|
||||
|
@ -463,3 +466,36 @@ def test_get_optional_params_num_retries():
|
|||
|
||||
print(f"mock_client.call_args: {mock_client.call_args}")
|
||||
assert mock_client.call_args.kwargs["max_retries"] == 10
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider",
|
||||
[
|
||||
"vertex_ai",
|
||||
"vertex_ai_beta",
|
||||
],
|
||||
)
|
||||
def test_vertex_safety_settings(provider):
|
||||
litellm.vertex_ai_safety_settings = [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"threshold": "BLOCK_NONE",
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"threshold": "BLOCK_NONE",
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"threshold": "BLOCK_NONE",
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"threshold": "BLOCK_NONE",
|
||||
},
|
||||
]
|
||||
|
||||
optional_params = get_optional_params(
|
||||
model="gemini-1.5-pro", custom_llm_provider=provider
|
||||
)
|
||||
assert len(optional_params) == 1
|
||||
|
|
|
@ -909,7 +909,7 @@ async def test_create_team_member_add(prisma_client, new_member_method):
|
|||
|
||||
await team_member_add(
|
||||
data=team_member_add_request,
|
||||
user_api_key_dict=UserAPIKeyAuth(),
|
||||
user_api_key_dict=UserAPIKeyAuth(user_role="proxy_admin"),
|
||||
http_request=Request(
|
||||
scope={"type": "http", "path": "/user/new"},
|
||||
),
|
||||
|
@ -930,6 +930,172 @@ async def test_create_team_member_add(prisma_client, new_member_method):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("team_member_role", ["admin", "user"])
|
||||
@pytest.mark.parametrize("team_route", ["/team/member_add", "/team/member_delete"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_team_member_add_team_admin_user_api_key_auth(
|
||||
prisma_client, team_member_role, team_route
|
||||
):
|
||||
import time
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
|
||||
from litellm.proxy.proxy_server import (
|
||||
ProxyException,
|
||||
hash_token,
|
||||
user_api_key_auth,
|
||||
user_api_key_cache,
|
||||
)
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
setattr(litellm, "max_internal_user_budget", 10)
|
||||
setattr(litellm, "internal_user_budget_duration", "5m")
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
user = f"ishaan {uuid.uuid4().hex}"
|
||||
_team_id = "litellm-test-client-id-new"
|
||||
user_key = "sk-12345678"
|
||||
|
||||
valid_token = UserAPIKeyAuth(
|
||||
team_id=_team_id,
|
||||
token=hash_token(user_key),
|
||||
team_member=Member(role=team_member_role, user_id=user),
|
||||
last_refreshed_at=time.time(),
|
||||
)
|
||||
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
|
||||
|
||||
team_obj = LiteLLM_TeamTableCachedObj(
|
||||
team_id=_team_id,
|
||||
blocked=False,
|
||||
last_refreshed_at=time.time(),
|
||||
metadata={"guardrails": {"modify_guardrails": False}},
|
||||
)
|
||||
|
||||
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
|
||||
|
||||
## TEST IF TEAM ADMIN ALLOWED TO CALL /MEMBER_ADD ENDPOINT
|
||||
import json
|
||||
|
||||
from starlette.datastructures import URL
|
||||
|
||||
request = Request(scope={"type": "http"})
|
||||
request._url = URL(url=team_route)
|
||||
|
||||
body = {}
|
||||
json_bytes = json.dumps(body).encode("utf-8")
|
||||
|
||||
request._body = json_bytes
|
||||
|
||||
## ALLOWED BY USER_API_KEY_AUTH
|
||||
await user_api_key_auth(request=request, api_key="Bearer " + user_key)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("new_member_method", ["user_id", "user_email"])
|
||||
@pytest.mark.parametrize("user_role", ["admin", "user"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_team_member_add_team_admin(
|
||||
prisma_client, new_member_method, user_role
|
||||
):
|
||||
"""
|
||||
Relevant issue - https://github.com/BerriAI/litellm/issues/5300
|
||||
|
||||
Allow team admins to:
|
||||
- Add and remove team members
|
||||
- raise error if team member not an existing 'internal_user'
|
||||
"""
|
||||
import time
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
|
||||
from litellm.proxy.proxy_server import (
|
||||
HTTPException,
|
||||
ProxyException,
|
||||
hash_token,
|
||||
user_api_key_auth,
|
||||
user_api_key_cache,
|
||||
)
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
setattr(litellm, "max_internal_user_budget", 10)
|
||||
setattr(litellm, "internal_user_budget_duration", "5m")
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
user = f"ishaan {uuid.uuid4().hex}"
|
||||
_team_id = "litellm-test-client-id-new"
|
||||
user_key = "sk-12345678"
|
||||
|
||||
valid_token = UserAPIKeyAuth(
|
||||
team_id=_team_id,
|
||||
user_id=user,
|
||||
token=hash_token(user_key),
|
||||
last_refreshed_at=time.time(),
|
||||
)
|
||||
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
|
||||
|
||||
team_obj = LiteLLM_TeamTableCachedObj(
|
||||
team_id=_team_id,
|
||||
blocked=False,
|
||||
last_refreshed_at=time.time(),
|
||||
members_with_roles=[Member(role=user_role, user_id=user)],
|
||||
metadata={"guardrails": {"modify_guardrails": False}},
|
||||
)
|
||||
|
||||
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
|
||||
if new_member_method == "user_id":
|
||||
data = {
|
||||
"team_id": _team_id,
|
||||
"member": [{"role": "user", "user_id": user}],
|
||||
}
|
||||
elif new_member_method == "user_email":
|
||||
data = {
|
||||
"team_id": _team_id,
|
||||
"member": [{"role": "user", "user_email": user}],
|
||||
}
|
||||
team_member_add_request = TeamMemberAddRequest(**data)
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_litellm_usertable:
|
||||
mock_client = AsyncMock()
|
||||
mock_litellm_usertable.upsert = mock_client
|
||||
mock_litellm_usertable.find_many = AsyncMock(return_value=None)
|
||||
|
||||
try:
|
||||
await team_member_add(
|
||||
data=team_member_add_request,
|
||||
user_api_key_dict=valid_token,
|
||||
http_request=Request(
|
||||
scope={"type": "http", "path": "/user/new"},
|
||||
),
|
||||
)
|
||||
except HTTPException as e:
|
||||
if user_role == "user":
|
||||
assert e.status_code == 403
|
||||
else:
|
||||
raise e
|
||||
|
||||
mock_client.assert_called()
|
||||
|
||||
print(f"mock_client.call_args: {mock_client.call_args}")
|
||||
print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs))
|
||||
|
||||
assert (
|
||||
mock_client.call_args.kwargs["data"]["create"]["max_budget"]
|
||||
== litellm.max_internal_user_budget
|
||||
)
|
||||
assert (
|
||||
mock_client.call_args.kwargs["data"]["create"]["budget_duration"]
|
||||
== litellm.internal_user_budget_duration
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_info_team_list(prisma_client):
|
||||
"""Assert user_info for admin calls team_list function"""
|
||||
|
|
litellm/types/llms/ollama.py (new file, 24 lines)
@ -0,0 +1,24 @@
import json
from typing import Any, Optional, TypedDict, Union

from pydantic import BaseModel
from typing_extensions import (
    Protocol,
    Required,
    Self,
    TypeGuard,
    get_origin,
    override,
    runtime_checkable,
)


class OllamaToolCallFunction(
    TypedDict
):  # follows - https://github.com/ollama/ollama/blob/6bd8a4b0a1ac15d5718f52bbe1cd56f827beb694/api/types.go#L148
    name: str
    arguments: dict


class OllamaToolCall(TypedDict):
    function: OllamaToolCallFunction
|
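As a rough illustration (not part of this diff), the new Ollama types can annotate a tool call parsed from an Ollama chat response; the payload shape below is assumed from the `api/types.go` reference in the comment above.

```python
import json

from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction

# Assumed shape of a single tool call returned by Ollama, for illustration only.
raw = '{"function": {"name": "get_weather", "arguments": {"city": "Paris"}}}'

tool_call: OllamaToolCall = json.loads(raw)
fn: OllamaToolCallFunction = tool_call["function"]
print(fn["name"], fn["arguments"])  # get_weather {'city': 'Paris'}
```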
@@ -1,6 +1,6 @@
import json
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union

from typing_extensions import (
    Protocol,

@@ -305,3 +305,18 @@ class ResponseTuningJob(TypedDict):
    ]
    createTime: Optional[str]
    updateTime: Optional[str]


class InstanceVideo(TypedDict, total=False):
    gcsUri: str
    videoSegmentConfig: Tuple[float, float, float]


class Instance(TypedDict, total=False):
    text: str
    image: Dict[str, str]
    video: InstanceVideo


class VertexMultimodalEmbeddingRequest(TypedDict, total=False):
    instances: List[Instance]
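For illustration only, a multimodal embedding request could be assembled from the new types like this; the import path and the meaning of the `videoSegmentConfig` tuple (start offset, end offset, interval, in seconds) are assumptions based on the surrounding diff, and the GCS URI is a placeholder.

```python
from litellm.types.llms.vertex_ai import (  # assumed module path
    Instance,
    InstanceVideo,
    VertexMultimodalEmbeddingRequest,
)

video = InstanceVideo(
    gcsUri="gs://my-bucket/clip.mp4",        # placeholder URI
    videoSegmentConfig=(0.0, 16.0, 4.0),     # assumed (start, end, interval) seconds
)
instance = Instance(text="a person riding a bike", video=video)
request = VertexMultimodalEmbeddingRequest(instances=[instance])
print(request)
```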
@@ -1116,6 +1116,7 @@ all_litellm_params = [
    "cooldown_time",
    "cache_key",
    "max_retries",
    "user_continue_message",
]


@@ -1218,6 +1219,7 @@ class StandardLoggingPayload(TypedDict):
    metadata: StandardLoggingMetadata
    cache_hit: Optional[bool]
    cache_key: Optional[str]
    saved_cache_cost: Optional[float]
    request_tags: list
    end_user: Optional[str]
    requester_ip_address: Optional[str]
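A hedged sketch (not from this diff) of consuming the new cache fields from a custom callback; the `standard_logging_object` key is an assumption about how the payload is surfaced in callback kwargs, so verify it against your litellm version.

```python
from litellm.integrations.custom_logger import CustomLogger


class CacheAuditLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # Assumption: StandardLoggingPayload is exposed under this kwargs key.
        payload = kwargs.get("standard_logging_object") or {}
        print(
            "cache_hit=", payload.get("cache_hit"),
            "cache_key=", payload.get("cache_key"),
            "saved_cache_cost=", payload.get("saved_cache_cost"),
        )
```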
@@ -541,7 +541,7 @@ def function_setup(
    call_type == CallTypes.embedding.value
    or call_type == CallTypes.aembedding.value
):
    messages = args[1] if len(args) > 1 else kwargs["input"]
    messages = args[1] if len(args) > 1 else kwargs.get("input", None)
elif (
    call_type == CallTypes.image_generation.value
    or call_type == CallTypes.aimage_generation.value
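The hunk above replaces a direct `kwargs["input"]` lookup with `kwargs.get("input", None)`, so embedding calls that reach this branch without an `input` keyword no longer raise a `KeyError`. A minimal standalone illustration:

```python
kwargs = {}  # embedding call where "input" was not passed as a keyword

# old behaviour: kwargs["input"] would raise KeyError
messages = kwargs.get("input", None)  # new behaviour: falls back to None
print(messages)  # None
```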
@@ -2323,6 +2323,7 @@ def get_litellm_params(
    output_cost_per_second=None,
    cooldown_time=None,
    text_completion=None,
    user_continue_message=None,
):
    litellm_params = {
        "acompletion": acompletion,

@@ -2347,6 +2348,7 @@ def get_litellm_params(
        "output_cost_per_second": output_cost_per_second,
        "cooldown_time": cooldown_time,
        "text_completion": text_completion,
        "user_continue_message": user_continue_message,
    }

    return litellm_params
@@ -3145,7 +3147,6 @@ def get_optional_params(
        or model in litellm.vertex_embedding_models
        or model in litellm.vertex_vision_models
    ):
        print_verbose(f"(start) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK")
        ## check if unsupported param passed in
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider

@@ -3157,9 +3158,8 @@ def get_optional_params(
            optional_params=optional_params,
        )

        print_verbose(
            f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}"
        )
        if litellm.vertex_ai_safety_settings is not None:
            optional_params["safety_settings"] = litellm.vertex_ai_safety_settings
    elif custom_llm_provider == "gemini":
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider

@@ -3170,7 +3170,7 @@ def get_optional_params(
            optional_params=optional_params,
            model=model,
        )
    elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
    elif custom_llm_provider == "vertex_ai_beta":
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
        )

@@ -3185,6 +3185,8 @@ def get_optional_params(
                else False
            ),
        )
        if litellm.vertex_ai_safety_settings is not None:
            optional_params["safety_settings"] = litellm.vertex_ai_safety_settings
    elif (
        custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models
    ):
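These hunks apply `litellm.vertex_ai_safety_settings` on the `vertex_ai` and `vertex_ai_beta` paths and give `gemini` its own branch. A minimal sketch of setting it globally; the category/threshold strings mirror Google's HarmCategory / HarmBlockThreshold enum names and are shown here as an assumption.

```python
import litellm

# Applied to Vertex AI requests when set (assumed dict format).
litellm.vertex_ai_safety_settings = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
]
```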
@@ -4219,6 +4221,7 @@ def get_supported_openai_params(
            "presence_penalty",
            "stop",
            "n",
            "extra_headers",
        ]
    elif custom_llm_provider == "cohere_chat":
        return [

@@ -4233,6 +4236,7 @@ def get_supported_openai_params(
            "tools",
            "tool_choice",
            "seed",
            "extra_headers",
        ]
    elif custom_llm_provider == "maritalk":
        return [
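`extra_headers` is added to the supported OpenAI params for the Cohere providers. A hedged usage sketch; the model string and header name/value are placeholders.

```python
import litellm

# Forward a custom header to Cohere (placeholder header name/value).
response = litellm.completion(
    model="cohere_chat/command-r",
    messages=[{"role": "user", "content": "Hello"}],
    extra_headers={"X-Client-Name": "my-app"},
)
```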
@@ -7121,6 +7125,14 @@ def exception_type(
                llm_provider="bedrock",
                response=original_exception.response,
            )
        elif "A conversation must start with a user message." in error_str:
            exception_mapping_worked = True
            raise BadRequestError(
                message=f"BedrockException - {error_str}\n. Pass in default user message via `completion(..,user_continue_message=)` or enable `litellm.modify_params=True`.\nFor Proxy: do via `litellm_settings::modify_params: True` or user_continue_message under `litellm_params`",
                model=model,
                llm_provider="bedrock",
                response=original_exception.response,
            )
        elif (
            "Unable to locate credentials" in error_str
            or "The security token included in the request is invalid"
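The new Bedrock error mapping points users at `user_continue_message` or `litellm.modify_params`. A hedged sketch of both options; the dict shape of `user_continue_message` is an assumption (the diff only names the parameter), and the call requires valid AWS credentials.

```python
import litellm

# Option 1: let litellm adjust the prompt automatically.
litellm.modify_params = True

# Option 2: supply a default user message explicitly (assumed message format).
response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "assistant", "content": "Earlier assistant turn"}],
    user_continue_message={"role": "user", "content": "Please continue."},
)
```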
@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "1.44.1"
version = "1.44.2"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"

@@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"

[tool.commitizen]
version = "1.44.1"
version = "1.44.2"
version_files = [
    "pyproject.toml:^version"
]