Merge branch 'BerriAI:main' into main

Author: Hannes Burrichter, 2024-05-11 18:28:16 +02:00 (committed by GitHub)
Commit: d0493248f4
37 changed files with 2206 additions and 551 deletions


@ -198,6 +198,7 @@ jobs:
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
-e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
-e AWS_REGION_NAME=$AWS_REGION_NAME \
-e AUTO_INFER_REGION=True \
-e OPENAI_API_KEY=$OPENAI_API_KEY \
-e LANGFUSE_PROJECT1_PUBLIC=$LANGFUSE_PROJECT1_PUBLIC \
-e LANGFUSE_PROJECT2_PUBLIC=$LANGFUSE_PROJECT2_PUBLIC \


@ -17,6 +17,14 @@ This covers:
- ✅ [**JWT-Auth**](../docs/proxy/token_auth.md)
## [COMING SOON] AWS Marketplace Support
Deploy managed LiteLLM Proxy within your VPC.
Includes all enterprise features.
[**Get early access**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
## Frequently Asked Questions
### What topics does Professional support cover and what SLAs do you offer?


@ -21,6 +21,11 @@ This is done by adding the "huggingface/" prefix to `model`, example `completion
<Tabs>
<TabItem value="tgi" label="Text-generation-interface (TGI)">
By default, LiteLLM will assume a huggingface call follows the TGI format.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
@ -40,9 +45,58 @@ response = completion(
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: wizard-coder
litellm_params:
model: huggingface/WizardLM/WizardCoder-Python-34B-V1.0
api_key: os.environ/HUGGINGFACE_API_KEY
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "wizard-coder",
"messages": [
{
"role": "user",
"content": "I like you!"
}
]
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="conv" label="Conversational-task (BlenderBot, etc.)">
Append `conversational` to the model name
e.g. `huggingface/conversational/<model-name>`
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
@ -54,7 +108,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
# e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
response = completion(
-model="huggingface/facebook/blenderbot-400M-distill",
model="huggingface/conversational/facebook/blenderbot-400M-distill",
messages=messages,
api_base="https://my-endpoint.huggingface.cloud"
)
@ -62,7 +116,123 @@ response = completion(
print(response)
```
</TabItem>
-<TabItem value="none" label="Non TGI/Conversational-task LLMs">
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: blenderbot
litellm_params:
model: huggingface/conversational/facebook/blenderbot-400M-distill
api_key: os.environ/HUGGINGFACE_API_KEY
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "blenderbot",
"messages": [
{
"role": "user",
"content": "I like you!"
}
]
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="classification" label="Text Classification">
Append `text-classification` to the model name
e.g. `huggingface/text-classification/<model-name>`
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
# [OPTIONAL] set env var
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
messages = [{ "content": "I like you, I love you!","role": "user"}]
# e.g. Call 'shahrukhx01/question-vs-statement-classifier' hosted on HF Inference endpoints
response = completion(
model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
messages=messages,
api_base="https://my-endpoint.endpoints.huggingface.cloud",
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: bert-classifier
litellm_params:
model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
api_key: os.environ/HUGGINGFACE_API_KEY
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "bert-classifier",
"messages": [
{
"role": "user",
"content": "I like you!"
}
]
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="none" label="Text Generation (NOT TGI)">
Append `text-generation` to the model name
e.g. `huggingface/text-generation/<model-name>`
```python
import os
@ -75,7 +245,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
# e.g. Call 'roneneldan/TinyStories-3M' hosted on HF Inference endpoints
response = completion(
-model="huggingface/roneneldan/TinyStories-3M",
model="huggingface/text-generation/roneneldan/TinyStories-3M",
messages=messages,
api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
)


@ -0,0 +1,247 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# 🆕 Predibase
LiteLLM supports all models on Predibase
## Usage
<Tabs>
<TabItem value="sdk" label="SDK">
### API KEYS
```python
import os
os.environ["PREDIBASE_API_KEY"] = ""
```
### Example Call
```python
from litellm import completion
import os
## set ENV variables
os.environ["PREDIBASE_API_KEY"] = "predibase key"
os.environ["PREDIBASE_TENANT_ID"] = "predibase tenant id"
# predibase llama-3 call
response = completion(
model="predibase/llama-3-8b-instruct",
messages = [{ "content": "Hello, how are you?","role": "user"}]
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: llama-3
litellm_params:
model: predibase/llama-3-8b-instruct
api_key: os.environ/PREDIBASE_API_KEY
tenant_id: os.environ/PREDIBASE_TENANT_ID
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Send Request to LiteLLM Proxy Server
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
```python
import openai
client = openai.OpenAI(
api_key="sk-1234", # pass litellm proxy key, if you're using virtual keys
base_url="http://0.0.0.0:4000" # litellm-proxy-base url
)
response = client.chat.completions.create(
model="llama-3",
messages = [
{
"role": "system",
"content": "Be a good human!"
},
{
"role": "user",
"content": "What do you know about earth?"
}
]
)
print(response)
```
</TabItem>
<TabItem value="curl" label="curl">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "llama-3",
"messages": [
{
"role": "system",
"content": "Be a good human!"
},
{
"role": "user",
"content": "What do you know about earth?"
}
]
}'
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
## Advanced Usage - Prompt Formatting
LiteLLM has prompt template mappings for all `meta-llama` llama3 instruct models. [**See Code**](https://github.com/BerriAI/litellm/blob/4f46b4c3975cd0f72b8c5acb2cb429d23580c18a/litellm/llms/prompt_templates/factory.py#L1360)
To apply a custom prompt template:
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import litellm
import os
os.environ["PREDIBASE_API_KEY"] = ""
# Create your own custom prompt template
litellm.register_prompt_template(
    model="togethercomputer/LLaMA-2-7B-32K",
    initial_prompt_value="You are a good assistant", # [OPTIONAL]
    roles={
        "system": {
            "pre_message": "[INST] <<SYS>>\n", # [OPTIONAL]
            "post_message": "\n<</SYS>>\n [/INST]\n" # [OPTIONAL]
        },
        "user": {
            "pre_message": "[INST] ", # [OPTIONAL]
            "post_message": " [/INST]" # [OPTIONAL]
        },
        "assistant": {
            "pre_message": "\n", # [OPTIONAL]
            "post_message": "\n" # [OPTIONAL]
        }
    },
    final_prompt_value="Now answer as best you can:" # [OPTIONAL]
)

def predibase_custom_model():
    model = "predibase/togethercomputer/LLaMA-2-7B-32K"
    messages = [{"role": "user", "content": "Hello, how are you?"}]
    response = litellm.completion(model=model, messages=messages)
    print(response['choices'][0]['message']['content'])
    return response

predibase_custom_model()
```
</TabItem>
<TabItem value="proxy" label="PROXY">
```yaml
# Model-specific parameters
model_list:
- model_name: mistral-7b # model alias
litellm_params: # actual params for litellm.completion()
model: "predibase/mistralai/Mistral-7B-Instruct-v0.1"
api_key: os.environ/PREDIBASE_API_KEY
initial_prompt_value: "\n"
roles: {"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}}
final_prompt_value: "\n"
bos_token: "<s>"
eos_token: "</s>"
max_tokens: 4096
```
</TabItem>
</Tabs>
## Passing additional params - max_tokens, temperature
See all litellm.completion supported params [here](https://docs.litellm.ai/docs/completion/input)
```python
# !pip install litellm
from litellm import completion
import os
## set ENV variables
os.environ["PREDIBASE_API_KEY"] = "predibase key"
# predibase llama-3 call
response = completion(
model="predibase/llama-3-8b-instruct",
messages = [{ "content": "Hello, how are you?","role": "user"}],
max_tokens=20,
temperature=0.5
)
```
**proxy**
```yaml
model_list:
- model_name: llama-3
litellm_params:
model: predibase/llama-3-8b-instruct
api_key: os.environ/PREDIBASE_API_KEY
max_tokens: 20
temperature: 0.5
```
## Passing Predibase-specific params - adapter_id, adapter_source
Send params that are [not supported by `litellm.completion()`](https://docs.litellm.ai/docs/completion/input) but are supported by Predibase by passing them to `litellm.completion`.
For example, `adapter_id` and `adapter_source` are Predibase-specific params - [See List](https://github.com/BerriAI/litellm/blob/8a35354dd6dbf4c2fcefcd6e877b980fcbd68c58/litellm/llms/predibase.py#L54)
```python
# !pip install litellm
from litellm import completion
import os
## set ENV variables
os.environ["PREDIBASE_API_KEY"] = "predibase key"
# predibase llama3 call
response = completion(
model="predibase/llama-3-8b-instruct",
messages = [{ "content": "Hello, how are you?","role": "user"}],
adapter_id="my_repo/3",
adapter_source="pbase",
)
```
**proxy**
```yaml
model_list:
- model_name: llama-3
litellm_params:
model: predibase/llama-3-8b-instruct
api_key: os.environ/PREDIBASE_API_KEY
adapter_id: my_repo/3
adapter_source: pbase
```


@ -0,0 +1,95 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Triton Inference Server
LiteLLM supports Embedding Models on Triton Inference Servers
## Usage
<Tabs>
<TabItem value="sdk" label="SDK">
### Example Call
Use the `triton/` prefix to route to triton server
```python
import litellm
response = await litellm.aembedding(
model="triton/<your-triton-model>",
api_base="https://your-triton-api-base/triton/embeddings", # /embeddings endpoint you want litellm to call on your server
input=["good morning from litellm"],
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: my-triton-model
litellm_params:
model: triton/<your-triton-model>
api_base: https://your-triton-api-base/triton/embeddings
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --detailed_debug
```
3. Send Request to LiteLLM Proxy Server
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
```python
import openai
from openai import OpenAI
# set base_url to your proxy server
# set api_key to send to proxy server
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")
response = client.embeddings.create(
input=["hello from litellm"],
model="my-triton-model"
)
print(response)
```
</TabItem>
<TabItem value="curl" label="curl">
`--header` is optional, only required if you're using litellm proxy with Virtual Keys
```shell
curl --location 'http://0.0.0.0:4000/embeddings' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data ' {
"model": "my-triton-model",
"input": ["write a litellm poem"]
}'
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>


@ -132,6 +132,9 @@ const sidebars = {
"providers/cohere", "providers/cohere",
"providers/anyscale", "providers/anyscale",
"providers/huggingface", "providers/huggingface",
"providers/watsonx",
"providers/predibase",
"providers/triton-inference-server",
"providers/ollama", "providers/ollama",
"providers/perplexity", "providers/perplexity",
"providers/groq", "providers/groq",
@ -151,7 +154,7 @@ const sidebars = {
"providers/openrouter", "providers/openrouter",
"providers/custom_openai_proxy", "providers/custom_openai_proxy",
"providers/petals", "providers/petals",
"providers/watsonx",
], ],
}, },
"proxy/custom_pricing", "proxy/custom_pricing",


@ -1,5 +1,6 @@
### Hide pydantic namespace conflict warnings globally ###
import warnings
warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
### INIT VARIABLES ###
import threading, requests, os
@ -537,6 +538,7 @@ provider_list: List = [
"xinference", "xinference",
"fireworks_ai", "fireworks_ai",
"watsonx", "watsonx",
"triton",
"predibase", "predibase",
"custom", # custom apis "custom", # custom apis
] ]


@ -262,7 +262,23 @@ class LangFuseLogger:
try:
tags = []
-metadata = copy.deepcopy(metadata) # Avoid modifying the original metadata
try:
metadata = copy.deepcopy(
metadata
) # Avoid modifying the original metadata
except:
new_metadata = {}
for key, value in metadata.items():
if (
isinstance(value, list)
or isinstance(value, dict)
or isinstance(value, str)
or isinstance(value, int)
or isinstance(value, float)
):
new_metadata[key] = copy.deepcopy(value)
metadata = new_metadata
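A minimal standalone sketch of what this fallback protects against (not LiteLLM code; the lock value is just a stand-in for any non-copyable object that can end up in request metadata):

```python
import copy
import threading

# metadata with one value that deepcopy cannot handle (a thread lock)
metadata = {"user_id": "u-123", "tags": ["prod"], "lock": threading.Lock()}

try:
    metadata = copy.deepcopy(metadata)
except Exception:
    # fall back to copying only plain, copyable values, mirroring the logic above
    metadata = {
        key: copy.deepcopy(value)
        for key, value in metadata.items()
        if isinstance(value, (list, dict, str, int, float))
    }

print(metadata)  # {'user_id': 'u-123', 'tags': ['prod']}
```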
supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3")
supports_prompt = Version(langfuse.version.__version__) >= Version("2.7.3")
supports_costs = Version(langfuse.version.__version__) >= Version("2.7.3")
@ -346,6 +362,7 @@ class LangFuseLogger:
"version": clean_metadata.pop( "version": clean_metadata.pop(
"trace_version", clean_metadata.get("version", None) "trace_version", clean_metadata.get("version", None)
), # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence ), # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence
"user_id": user_id,
} }
for key in list( for key in list(
filter(lambda key: key.startswith("trace_"), clean_metadata.keys()) filter(lambda key: key.startswith("trace_"), clean_metadata.keys())


@ -4,7 +4,6 @@ from datetime import datetime, timezone
import traceback
import dotenv
import importlib
-import sys
import packaging
@ -18,13 +17,33 @@ def parse_usage(usage):
"prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0, "prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
} }
def parse_tool_calls(tool_calls):
if tool_calls is None:
return None
def clean_tool_call(tool_call):
serialized = {
"type": tool_call.type,
"id": tool_call.id,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
}
}
return serialized
return [clean_tool_call(tool_call) for tool_call in tool_calls]
def parse_messages(input):
if input is None:
return None
def clean_message(message):
-# if is strin, return as is
# if is string, return as is
if isinstance(message, str):
return message
@ -38,9 +57,7 @@ def parse_messages(input):
# Only add tool_calls and function_call to res if they are set
if message.get("tool_calls"):
-serialized["tool_calls"] = message.get("tool_calls")
serialized["tool_calls"] = parse_tool_calls(message.get("tool_calls"))
-if message.get("function_call"):
-serialized["function_call"] = message.get("function_call")
return serialized
@ -93,8 +110,13 @@ class LunaryLogger:
print_verbose(f"Lunary Logging - Logging request for model {model}") print_verbose(f"Lunary Logging - Logging request for model {model}")
litellm_params = kwargs.get("litellm_params", {}) litellm_params = kwargs.get("litellm_params", {})
optional_params = kwargs.get("optional_params", {})
metadata = litellm_params.get("metadata", {}) or {} metadata = litellm_params.get("metadata", {}) or {}
if optional_params:
# merge into extra
extra = {**extra, **optional_params}
tags = litellm_params.pop("tags", None) or [] tags = litellm_params.pop("tags", None) or []
if extra: if extra:
@ -104,7 +126,7 @@ class LunaryLogger:
# keep only serializable types
for param, value in extra.items():
-if not isinstance(value, (str, int, bool, float)):
if not isinstance(value, (str, int, bool, float)) and param != "tools":
try:
extra[param] = str(value)
except:
@ -140,7 +162,7 @@ class LunaryLogger:
metadata=metadata,
runtime="litellm",
tags=tags,
-extra=extra,
params=extra,
)
self.lunary_client.track_event(


@ -8,6 +8,7 @@ from litellm.utils import (
CustomStreamWrapper,
convert_to_model_response_object,
TranscriptionResponse,
get_secret,
)
from typing import Callable, Optional, BinaryIO
from litellm import OpenAIConfig
@ -16,6 +17,7 @@ import httpx # type: ignore
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI
import uuid
import os
class AzureOpenAIError(Exception):
@ -126,6 +128,51 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
return azure_client_params
def get_azure_ad_token_from_oidc(azure_ad_token: str):
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
azure_tenant = os.getenv("AZURE_TENANT_ID", None)
if azure_client_id is None or azure_tenant is None:
raise AzureOpenAIError(
status_code=422,
message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
)
oidc_token = get_secret(azure_ad_token)
if oidc_token is None:
raise AzureOpenAIError(
status_code=401,
message="OIDC token could not be retrieved from secret manager.",
)
req_token = httpx.post(
f"https://login.microsoftonline.com/{azure_tenant}/oauth2/v2.0/token",
data={
"client_id": azure_client_id,
"grant_type": "client_credentials",
"scope": "https://cognitiveservices.azure.com/.default",
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": oidc_token,
},
)
if req_token.status_code != 200:
raise AzureOpenAIError(
status_code=req_token.status_code,
message=req_token.text,
)
possible_azure_ad_token = req_token.json().get("access_token", None)
if possible_azure_ad_token is None:
raise AzureOpenAIError(
status_code=422, message="Azure AD Token not returned"
)
return possible_azure_ad_token
class AzureChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
@ -137,6 +184,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
headers["api-key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
headers["Authorization"] = f"Bearer {azure_ad_token}"
return headers
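For context, a hypothetical caller-side sketch of the new `oidc/` flow (not part of this diff): the deployment name, endpoint, and the `oidc/...` value below are placeholders, and the exact secret-reference format depends on how `get_secret()` resolves OIDC tokens in your environment.

```python
import os
import litellm

# Required by get_azure_ad_token_from_oidc(); values are placeholders.
os.environ["AZURE_CLIENT_ID"] = "<your-entra-app-client-id>"
os.environ["AZURE_TENANT_ID"] = "<your-entra-tenant-id>"

response = litellm.completion(
    model="azure/my-gpt-35-deployment",              # hypothetical deployment name
    messages=[{"role": "user", "content": "Hello"}],
    api_base="https://my-endpoint.openai.azure.com",
    api_version="2023-07-01-preview",
    azure_ad_token="oidc/my-oidc-token-reference",   # "oidc/" prefix triggers the exchange above
)
print(response)
```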
@ -189,6 +238,9 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if acompletion is True:
@ -276,6 +328,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
azure_client = AzureOpenAI(**azure_client_params)
@ -351,6 +405,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
# setting Azure client
@ -422,6 +478,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
azure_client = AzureOpenAI(**azure_client_params)
@ -478,6 +536,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
azure_client = AsyncAzureOpenAI(**azure_client_params)
@ -599,6 +659,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
## LOGGING
@ -755,6 +817,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if aimg_generation == True:
@ -833,6 +897,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if max_retries is not None:


@ -551,6 +551,7 @@ def init_bedrock_client(
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
aws_web_identity_token: Optional[str] = None,
extra_headers: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
):
@ -567,6 +568,7 @@ def init_bedrock_client(
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
]
# Iterate over parameters and update if needed
@ -582,6 +584,7 @@ def init_bedrock_client(
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
) = params_to_check
### SET REGION NAME
@ -620,7 +623,38 @@ def init_bedrock_client(
config = boto3.session.Config()
### CHECK STS ###
-if aws_role_name is not None and aws_session_name is not None:
if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client(
"sts"
)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
client = boto3.client(
service_name="bedrock-runtime",
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=region_name,
endpoint_url=endpoint_url,
config=config,
)
elif aws_role_name is not None and aws_session_name is not None:
# use sts if role name passed in
sts_client = boto3.client(
"sts",
@ -755,6 +789,7 @@ def completion(
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = optional_params.pop("aws_bedrock_client", None)
@ -769,6 +804,7 @@ def completion(
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_web_identity_token=aws_web_identity_token,
extra_headers=extra_headers,
timeout=timeout,
)
@ -1291,6 +1327,7 @@ def embedding(
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = init_bedrock_client(
@ -1298,6 +1335,7 @@ def embedding(
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
)
@ -1380,6 +1418,7 @@ def image_generation(
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = init_bedrock_client(
@ -1387,6 +1426,7 @@ def image_generation(
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
timeout=timeout,


@ -6,10 +6,12 @@ import httpx, requests
from .base import BaseLLM
import time
import litellm
-from typing import Callable, Dict, List, Any
from typing import Callable, Dict, List, Any, Literal
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
from typing import Optional
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.types.completion import ChatCompletionMessageToolCallParam
import enum
class HuggingfaceError(Exception):
@ -39,11 +41,29 @@ class HuggingfaceError(Exception):
) # Call the base class constructor with the parameters it needs
hf_task_list = [
"text-generation-inference",
"conversational",
"text-classification",
"text-generation",
]
hf_tasks = Literal[
"text-generation-inference",
"conversational",
"text-classification",
"text-generation",
]
class HuggingfaceConfig:
"""
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
"""
hf_task: Optional[hf_tasks] = (
None # litellm-specific param, used to know the api spec to use when calling huggingface api
)
best_of: Optional[int] = None
decoder_input_details: Optional[bool] = None
details: Optional[bool] = True # enables returning logprobs + best of
@ -101,6 +121,51 @@ class HuggingfaceConfig:
and v is not None
}
def get_supported_openai_params(self):
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"stop",
"n",
"echo",
]
def map_openai_params(
self, non_default_params: dict, optional_params: dict
) -> dict:
for param, value in non_default_params.items():
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
if param == "temperature":
if value == 0.0 or value == 0:
# hugging face exception raised when temp==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
value = 0.01
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "n":
optional_params["best_of"] = value
optional_params["do_sample"] = (
True # Need to sample if you want best of for hf inference endpoints
)
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop"] = value
if param == "max_tokens":
# HF TGI raises the following exception when max_new_tokens==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
if value == 0:
value = 1
optional_params["max_new_tokens"] = value
if param == "echo":
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
optional_params["decoder_input_details"] = True
return optional_params
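A small sketch of how the mapping above behaves, assuming `HuggingfaceConfig` can be instantiated with its defaults; the expected output is read directly from the logic above.

```python
from litellm.llms.huggingface_restapi import HuggingfaceConfig

mapped = HuggingfaceConfig().map_openai_params(
    non_default_params={"temperature": 0, "max_tokens": 0, "n": 2, "stop": ["###"]},
    optional_params={},
)
# Per the code above: temperature=0 is bumped to 0.01, max_tokens=0 becomes
# max_new_tokens=1, and n maps to best_of with do_sample=True.
print(mapped)
# {'temperature': 0.01, 'max_new_tokens': 1, 'best_of': 2, 'do_sample': True, 'stop': ['###']}
```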
def output_parser(generated_text: str):
"""
@ -162,16 +227,18 @@ def read_tgi_conv_models():
return set(), set()
-def get_hf_task_for_model(model):
def get_hf_task_for_model(model: str) -> hf_tasks:
# read text file, cast it to set
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
if model.split("/")[0] in hf_task_list:
return model.split("/")[0] # type: ignore
tgi_models, conversational_models = read_tgi_conv_models()
if model in tgi_models:
return "text-generation-inference"
elif model in conversational_models:
return "conversational"
elif "roneneldan/TinyStories" in model:
-return None
return "text-generation"
else:
return "text-generation-inference" # default to tgi
@ -202,7 +269,7 @@ class Huggingface(BaseLLM):
self,
completion_response,
model_response,
-task,
task: hf_tasks,
optional_params,
encoding,
input_text,
@ -270,6 +337,10 @@ class Huggingface(BaseLLM):
)
choices_list.append(choice_obj)
model_response["choices"].extend(choices_list)
elif task == "text-classification":
model_response["choices"][0]["message"]["content"] = json.dumps(
completion_response
)
else:
if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = output_parser(
@ -333,6 +404,12 @@ class Huggingface(BaseLLM):
try:
headers = self.validate_environment(api_key, headers)
task = get_hf_task_for_model(model)
## VALIDATE API FORMAT
if task is None or not isinstance(task, str) or task not in hf_task_list:
raise Exception(
"Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
)
print_verbose(f"{model}, {task}")
completion_url = ""
input_text = ""
@ -433,14 +510,15 @@ class Huggingface(BaseLLM):
inference_params.pop("return_full_text")
data = {
"inputs": prompt,
-"parameters": inference_params,
-"stream": ( # type: ignore
}
if task == "text-generation-inference":
data["parameters"] = inference_params
data["stream"] = ( # type: ignore
True
if "stream" in optional_params
and optional_params["stream"] == True
else False
-),
-}
)
input_text = prompt
## LOGGING
logging_obj.pre_call(
@ -531,10 +609,10 @@ class Huggingface(BaseLLM):
isinstance(completion_response, dict)
and "error" in completion_response
):
-print_verbose(f"completion error: {completion_response['error']}")
print_verbose(f"completion error: {completion_response['error']}") # type: ignore
print_verbose(f"response.status_code: {response.status_code}")
raise HuggingfaceError(
-message=completion_response["error"],
message=completion_response["error"], # type: ignore
status_code=response.status_code,
)
return self.convert_to_model_response_object(
@ -563,7 +641,7 @@ class Huggingface(BaseLLM):
data: dict,
headers: dict,
model_response: ModelResponse,
-task: str,
task: hf_tasks,
encoding: Any,
input_text: str,
model: str,

litellm/llms/triton.py (new file, 119 lines)

@ -0,0 +1,119 @@
import os, types
import json
from enum import Enum
import requests, copy # type: ignore
import time
from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx # type: ignore
class TritonError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST",
url="https://api.anthropic.com/v1/messages", # using anthropic api base since httpx requires a url
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class TritonChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
async def aembedding(
self,
data: dict,
model_response: litellm.utils.EmbeddingResponse,
api_base: str,
logging_obj=None,
api_key: Optional[str] = None,
):
async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
response = await async_handler.post(url=api_base, data=json.dumps(data))
if response.status_code != 200:
raise TritonError(status_code=response.status_code, message=response.text)
_text_response = response.text
logging_obj.post_call(original_response=_text_response)
_json_response = response.json()
_outputs = _json_response["outputs"]
_output_data = _outputs[0]["data"]
_embedding_output = {
"object": "embedding",
"index": 0,
"embedding": _output_data,
}
model_response.model = _json_response.get("model_name", "None")
model_response.data = [_embedding_output]
return model_response
def embedding(
self,
model: str,
input: list,
timeout: float,
api_base: str,
model_response: litellm.utils.EmbeddingResponse,
api_key: Optional[str] = None,
logging_obj=None,
optional_params=None,
client=None,
aembedding=None,
):
data_for_triton = {
"inputs": [
{
"name": "input_text",
"shape": [1],
"datatype": "BYTES",
"data": input,
}
]
}
## LOGGING
curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"
logging_obj.pre_call(
input="",
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": curl_string,
},
)
if aembedding == True:
response = self.aembedding(
data=data_for_triton,
model_response=model_response,
logging_obj=logging_obj,
api_base=api_base,
api_key=api_key,
)
return response
else:
raise Exception(
"Only async embedding supported for triton, please use litellm.aembedding() for now"
)
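A standalone sketch (not LiteLLM code) of how a Triton embeddings response of the shape expected by `aembedding()` above maps onto an OpenAI-style embedding object; the field values are made up.

```python
# Hypothetical Triton response with the fields aembedding() reads:
# "model_name" and outputs[0]["data"].
triton_response = {
    "model_name": "my-triton-model",
    "outputs": [{"name": "embedding", "shape": [3], "datatype": "FP32", "data": [0.12, -0.08, 0.33]}],
}

embedding_object = {
    "object": "embedding",
    "index": 0,
    "embedding": triton_response["outputs"][0]["data"],
}
print({"model": triton_response.get("model_name", "None"), "data": [embedding_object]})
```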


@ -419,6 +419,7 @@ def completion(
from google.protobuf.struct_pb2 import Value # type: ignore
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types # type: ignore
import google.auth # type: ignore
import proto # type: ignore
## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
print_verbose(
@ -605,9 +606,21 @@ def completion(
):
function_call = response.candidates[0].content.parts[0].function_call
args_dict = {}
-for k, v in function_call.args.items():
-args_dict[k] = v
-args_str = json.dumps(args_dict)
# Check if it's a RepeatedComposite instance
for key, val in function_call.args.items():
if isinstance(
val, proto.marshal.collections.repeated.RepeatedComposite
):
# If so, convert to list
args_dict[key] = [v for v in val]
else:
args_dict[key] = val
try:
args_str = json.dumps(args_dict)
except Exception as e:
raise VertexAIError(status_code=422, message=str(e))
message = litellm.Message(
content=None,
tool_calls=[
@ -810,6 +823,8 @@ def completion(
setattr(model_response, "usage", usage)
return model_response
except Exception as e:
if isinstance(e, VertexAIError):
raise e
raise VertexAIError(status_code=500, message=str(e))


@ -14,7 +14,6 @@ import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
import httpx
import litellm
from ._logging import verbose_logger
from litellm import ( # type: ignore
client,
@ -47,6 +46,7 @@ from .llms import (
ai21,
sagemaker,
bedrock,
triton,
huggingface_restapi,
replicate,
aleph_alpha,
@ -75,6 +75,7 @@ from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import (
prompt_factory,
custom_prompt,
@ -112,6 +113,7 @@ azure_chat_completions = AzureChatCompletion()
azure_text_completions = AzureTextCompletion()
huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion()
####### COMPLETION ENDPOINTS ################
@ -662,6 +664,7 @@ def completion(
"region_name", "region_name",
"allowed_model_region", "allowed_model_region",
] ]
default_params = openai_params + litellm_params default_params = openai_params + litellm_params
non_default_params = { non_default_params = {
k: v for k, v in kwargs.items() if k not in default_params k: v for k, v in kwargs.items() if k not in default_params
@ -2621,6 +2624,7 @@ async def aembedding(*args, **kwargs):
or custom_llm_provider == "voyage" or custom_llm_provider == "voyage"
or custom_llm_provider == "mistral" or custom_llm_provider == "mistral"
or custom_llm_provider == "custom_openai" or custom_llm_provider == "custom_openai"
or custom_llm_provider == "triton"
or custom_llm_provider == "anyscale" or custom_llm_provider == "anyscale"
or custom_llm_provider == "openrouter" or custom_llm_provider == "openrouter"
or custom_llm_provider == "deepinfra" or custom_llm_provider == "deepinfra"
@ -2954,23 +2958,43 @@ def embedding(
optional_params=optional_params,
model_response=EmbeddingResponse(),
)
elif custom_llm_provider == "triton":
if api_base is None:
raise ValueError(
"api_base is required for triton. Please pass `api_base`"
)
response = triton_chat_completions.embedding(
model=model,
input=input,
api_base=api_base,
api_key=api_key,
logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(),
optional_params=optional_params,
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "vertex_ai":
vertex_ai_project = (
optional_params.pop("vertex_project", None)
or optional_params.pop("vertex_ai_project", None)
or litellm.vertex_project
or get_secret("VERTEXAI_PROJECT")
or get_secret("VERTEX_PROJECT")
)
vertex_ai_location = (
optional_params.pop("vertex_location", None)
or optional_params.pop("vertex_ai_location", None)
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
or get_secret("VERTEX_LOCATION")
)
vertex_credentials = (
optional_params.pop("vertex_credentials", None)
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
or get_secret("VERTEX_CREDENTIALS")
)
response = vertex_ai.embedding(


@ -20,22 +20,20 @@ model_list:
- litellm_params:
model: together_ai/codellama/CodeLlama-13b-Instruct-hf
model_name: CodeLlama-13b-Instruct
-router_settings:
-num_retries: 0
-enable_pre_call_checks: true
-redis_host: os.environ/REDIS_HOST
-redis_password: os.environ/REDIS_PASSWORD
-redis_port: os.environ/REDIS_PORT
router_settings:
-routing_strategy: "latency-based-routing"
redis_host: redis
# redis_password: <your redis password>
redis_port: 6379
litellm_settings:
-success_callback: ["langfuse"]
set_verbose: True
-# service_callback: ["prometheus_system"]
-# success_callback: ["prometheus"]
-# failure_callback: ["prometheus"]
general_settings:
-alerting: ["slack"]
enable_jwt_auth: True
-alert_types: ["llm_exceptions", "daily_reports"]
disable_reset_budget: True
-alerting_args:
proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
-daily_report_frequency: 60 # every minute
routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle"
-report_check_interval: 5 # every 5s


@ -156,6 +156,11 @@ class JWTHandler:
return public_key
async def auth_jwt(self, token: str) -> dict:
# Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html
# "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret
# the key in different ways (e.g. HS* and RS*)."
algorithms = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]
audience = os.getenv("JWT_AUDIENCE")
decode_options = None
if audience is None:
@ -189,7 +194,7 @@ class JWTHandler:
payload = jwt.decode(
token,
public_key_rsa, # type: ignore
-algorithms=["RS256"],
algorithms=algorithms,
options=decode_options,
audience=audience,
)
@ -214,7 +219,7 @@ class JWTHandler:
payload = jwt.decode(
token,
key,
-algorithms=["RS256"],
algorithms=algorithms,
audience=audience,
options=decode_options
)


@ -8,7 +8,10 @@ model_list:
litellm_params:
model: openai/*
api_key: os.environ/OPENAI_API_KEY
- model_name: my-triton-model
litellm_params:
model: triton/any
api_base: https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings
general_settings:
store_model_in_db: true
@ -17,4 +20,10 @@ general_settings:
litellm_settings:
success_callback: ["langfuse"]
-_langfuse_default_tags: ["user_api_key_alias", "user_api_key_user_id", "user_api_key_user_email", "user_api_key_team_alias", "semantic-similarity", "proxy_base_url"]
failure_callback: ["langfuse"]
default_team_settings:
- team_id: 7bf09cd5-217a-40d4-8634-fc31d9b88bf4
success_callback: ["langfuse"]
failure_callback: ["langfuse"]
langfuse_public_key: "os.environ/LANGFUSE_DEV_PUBLIC_KEY"
langfuse_secret_key: "os.environ/LANGFUSE_DEV_SK_KEY"


@ -7795,11 +7795,15 @@ async def update_model(
)
async def model_info_v2(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
model: Optional[str] = fastapi.Query(
None, description="Specify the model name (optional)"
),
debug: Optional[bool] = False,
):
"""
BETA ENDPOINT. Might change unexpectedly. Use `/v1/model/info` for now.
"""
-global llm_model_list, general_settings, user_config_file_path, proxy_config
global llm_model_list, general_settings, user_config_file_path, proxy_config, llm_router
if llm_model_list is None or not isinstance(llm_model_list, list):
@ -7822,19 +7826,35 @@ async def model_info_v2(
if len(user_api_key_dict.models) > 0:
user_models = user_api_key_dict.models
if model is not None:
all_models = [m for m in all_models if m["model_name"] == model]
# fill in model info based on config.yaml and litellm model_prices_and_context_window.json
-for model in all_models:
for _model in all_models:
# provided model_info in config.yaml
-model_info = model.get("model_info", {})
model_info = _model.get("model_info", {})
if debug == True:
_openai_client = "None"
if llm_router is not None:
_openai_client = (
llm_router._get_client(
deployment=_model, kwargs={}, client_type="async"
)
or "None"
)
else:
_openai_client = "llm_router_is_None"
openai_client = str(_openai_client)
_model["openai_client"] = openai_client
# read litellm model_prices_and_context_window.json to get the following:
# input_cost_per_token, output_cost_per_token, max_tokens
-litellm_model_info = get_litellm_model_info(model=model)
litellm_model_info = get_litellm_model_info(model=_model)
# 2nd pass on the model, try seeing if we can find model in litellm model_cost map
if litellm_model_info == {}:
# use litellm_param model_name to get model_info
-litellm_params = model.get("litellm_params", {})
litellm_params = _model.get("litellm_params", {})
litellm_model = litellm_params.get("model", None)
try:
litellm_model_info = litellm.get_model_info(model=litellm_model)
@ -7843,7 +7863,7 @@ async def model_info_v2(
# 3rd pass on the model, try seeing if we can find model but without the "/" in model cost map
if litellm_model_info == {}:
# use litellm_param model_name to get model_info
-litellm_params = model.get("litellm_params", {})
litellm_params = _model.get("litellm_params", {})
litellm_model = litellm_params.get("model", None)
split_model = litellm_model.split("/")
if len(split_model) > 0:
@ -7855,10 +7875,10 @@ async def model_info_v2(
for k, v in litellm_model_info.items():
if k not in model_info:
model_info[k] = v
-model["model_info"] = model_info
_model["model_info"] = model_info
# don't return the api key / vertex credentials
-model["litellm_params"].pop("api_key", None)
_model["litellm_params"].pop("api_key", None)
-model["litellm_params"].pop("vertex_credentials", None)
_model["litellm_params"].pop("vertex_credentials", None)
verbose_proxy_logger.debug("all_models: %s", all_models)
return {"data": all_models}


@ -9,7 +9,7 @@
import copy, httpx
from datetime import datetime
-from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple
import random, threading, time, traceback, uuid
import litellm, openai, hashlib, json
from litellm.caching import RedisCache, InMemoryCache, DualCache
@ -48,6 +48,7 @@ from litellm.types.router import (
AlertingConfig, AlertingConfig,
) )
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
class Router: class Router:
@ -102,6 +103,7 @@ class Router:
"usage-based-routing", "usage-based-routing",
"latency-based-routing", "latency-based-routing",
"cost-based-routing", "cost-based-routing",
"usage-based-routing-v2",
] = "simple-shuffle", ] = "simple-shuffle",
routing_strategy_args: dict = {}, # just for latency-based routing routing_strategy_args: dict = {}, # just for latency-based routing
semaphore: Optional[asyncio.Semaphore] = None, semaphore: Optional[asyncio.Semaphore] = None,
@ -2114,6 +2116,10 @@ class Router:
raise ValueError( raise ValueError(
f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}" f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
) )
azure_ad_token = litellm_params.get("azure_ad_token")
if azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
if api_version is None: if api_version is None:
api_version = "2023-07-01-preview" api_version = "2023-07-01-preview"
@ -2125,6 +2131,7 @@ class Router:
cache_key = f"{model_id}_async_client" cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI( _client = openai.AsyncAzureOpenAI(
api_key=api_key, api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base, base_url=api_base,
api_version=api_version, api_version=api_version,
timeout=timeout, timeout=timeout,
@ -2149,6 +2156,7 @@ class Router:
cache_key = f"{model_id}_client" cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore _client = openai.AzureOpenAI( # type: ignore
api_key=api_key, api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base, base_url=api_base,
api_version=api_version, api_version=api_version,
timeout=timeout, timeout=timeout,
@ -2173,6 +2181,7 @@ class Router:
cache_key = f"{model_id}_stream_async_client" cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore _client = openai.AsyncAzureOpenAI( # type: ignore
api_key=api_key, api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base, base_url=api_base,
api_version=api_version, api_version=api_version,
timeout=stream_timeout, timeout=stream_timeout,
@ -2197,6 +2206,7 @@ class Router:
cache_key = f"{model_id}_stream_client" cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore _client = openai.AzureOpenAI( # type: ignore
api_key=api_key, api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base, base_url=api_base,
api_version=api_version, api_version=api_version,
timeout=stream_timeout, timeout=stream_timeout,
@ -2229,6 +2239,7 @@ class Router:
"api_key": api_key, "api_key": api_key,
"azure_endpoint": api_base, "azure_endpoint": api_base,
"api_version": api_version, "api_version": api_version,
"azure_ad_token": azure_ad_token,
} }
from litellm.llms.azure import select_azure_base_url_or_endpoint from litellm.llms.azure import select_azure_base_url_or_endpoint
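
These hunks thread an optional `azure_ad_token` through every cached Azure client, resolving values prefixed with `oidc/` via `get_azure_ad_token_from_oidc`. A hedged sketch of a router deployment using it; the endpoint and OIDC path are placeholders, and resolving the token only works where the named OIDC provider is reachable:

```python
import litellm

# Placeholder endpoint and OIDC provider path - not real values.
router = litellm.Router(
    model_list=[
        {
            "model_name": "azure-gpt-35",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_base": "https://my-endpoint.openai.azure.com",
                "api_version": "2023-07-01-preview",
                # exchanged for an Azure AD token when the client is set up
                "azure_ad_token": "oidc/google/https://example-audience",
            },
        }
    ]
)
```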
@ -2557,20 +2568,27 @@ class Router:
self.set_client(model=deployment.to_json(exclude_none=True)) self.set_client(model=deployment.to_json(exclude_none=True))
# set region (if azure model) # set region (if azure model)
try: _auto_infer_region = os.environ.get("AUTO_INFER_REGION", False)
if "azure" in deployment.litellm_params.model: if _auto_infer_region == True or _auto_infer_region == "True":
region = litellm.utils.get_model_region( print("Auto inferring region") # noqa
litellm_params=deployment.litellm_params, mode=None """
) Hiding behind a feature flag
When there is a large number of LLM deployments, this makes startup times blow up

"""
try:
if "azure" in deployment.litellm_params.model:
region = litellm.utils.get_model_region(
litellm_params=deployment.litellm_params, mode=None
)
deployment.litellm_params.region_name = region deployment.litellm_params.region_name = region
except Exception as e: except Exception as e:
verbose_router_logger.error( verbose_router_logger.error(
"Unable to get the region for azure model - {}, {}".format( "Unable to get the region for azure model - {}, {}".format(
deployment.litellm_params.model, str(e) deployment.litellm_params.model, str(e)
)
) )
) pass # [NON-BLOCKING]
pass # [NON-BLOCKING]
return deployment return deployment
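
Region inference is now opt-in via the `AUTO_INFER_REGION` environment variable (also exported in the CI job earlier in this diff), since probing every Azure deployment for its region slows startup. A minimal sketch with placeholder credentials:

```python
import os
import litellm

# Opt in before constructing the Router; off by default to keep startup fast.
os.environ["AUTO_INFER_REGION"] = "True"

router = litellm.Router(
    model_list=[
        {
            "model_name": "azure-gpt-35",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_base": "https://my-endpoint.openai.azure.com",  # placeholder
                "api_key": "placeholder-key",
                "api_version": "2023-07-01-preview",
            },
        }
    ]
)
# With the flag set, deployment.litellm_params.region_name is populated best-effort.
```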
@ -2599,7 +2617,7 @@ class Router:
self.model_names.append(deployment.model_name) self.model_names.append(deployment.model_name)
return deployment return deployment
def upsert_deployment(self, deployment: Deployment) -> Deployment: def upsert_deployment(self, deployment: Deployment) -> Optional[Deployment]:
""" """
Add or update deployment Add or update deployment
Parameters: Parameters:
@ -2609,8 +2627,17 @@ class Router:
- The added/updated deployment - The added/updated deployment
""" """
# check if deployment already exists # check if deployment already exists
_deployment_model_id = deployment.model_info.id or ""
_deployment_on_router: Optional[Deployment] = self.get_deployment(
model_id=_deployment_model_id
)
if _deployment_on_router is not None:
# deployment with this model_id exists on the router
if deployment.litellm_params == _deployment_on_router.litellm_params:
# No need to update
return None
if deployment.model_info.id in self.get_model_ids(): # if there is a new litellm param -> then update the deployment
# remove the previous deployment # remove the previous deployment
removal_idx: Optional[int] = None removal_idx: Optional[int] = None
for idx, model in enumerate(self.model_list): for idx, model in enumerate(self.model_list):
@ -2619,16 +2646,9 @@ class Router:
if removal_idx is not None: if removal_idx is not None:
self.model_list.pop(removal_idx) self.model_list.pop(removal_idx)
else:
# add to model list # if the model_id is not in router
_deployment = deployment.to_json(exclude_none=True) self.add_deployment(deployment=deployment)
self.model_list.append(_deployment)
# initialize client
self._add_deployment(deployment=deployment)
# add to model names
self.model_names.append(deployment.model_name)
return deployment return deployment
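
`upsert_deployment` now returns `None` when the incoming deployment's `litellm_params` match what the router already holds, instead of tearing down and re-adding the client. A hedged usage sketch; the model id and key are placeholders:

```python
import litellm

router = litellm.Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "placeholder-key"},
            "model_info": {"id": "deployment-1"},  # stable id so we can look it up
        }
    ]
)

existing = router.get_deployment(model_id="deployment-1")
# Re-upserting an identical deployment is now a no-op and returns None.
print(router.upsert_deployment(existing))
```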
def delete_deployment(self, id: str) -> Optional[Deployment]: def delete_deployment(self, id: str) -> Optional[Deployment]:
@ -2989,11 +3009,15 @@ class Router:
messages: Optional[List[Dict[str, str]]] = None, messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None, input: Optional[Union[str, List]] = None,
specific_deployment: Optional[bool] = False, specific_deployment: Optional[bool] = False,
): ) -> Tuple[str, Union[list, dict]]:
""" """
Common checks for 'get_available_deployment' across sync + async call. Common checks for 'get_available_deployment' across sync + async call.
If 'healthy_deployments' returned is None, this means the user chose a specific deployment If 'healthy_deployments' returned is None, this means the user chose a specific deployment
Returns (model, deployments):
- deployments is a Dict, if a specific deployment was chosen
- deployments is a List of healthy deployments, otherwise
""" """
# check if aliases set on litellm model alias map # check if aliases set on litellm model alias map
if specific_deployment == True: if specific_deployment == True:
@ -3003,7 +3027,7 @@ class Router:
if deployment_model == model: if deployment_model == model:
# User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2 # User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
# return the first deployment where the `model` matches the specified deployment name # return the first deployment where the `model` matches the specified deployment name
return deployment, None return deployment_model, deployment
raise ValueError( raise ValueError(
f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}" f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}"
) )
@ -3019,7 +3043,7 @@ class Router:
self.default_deployment self.default_deployment
) # self.default_deployment ) # self.default_deployment
updated_deployment["litellm_params"]["model"] = model updated_deployment["litellm_params"]["model"] = model
return updated_deployment, None return model, updated_deployment
## get healthy deployments ## get healthy deployments
### get all deployments ### get all deployments
@ -3072,10 +3096,10 @@ class Router:
messages=messages, messages=messages,
input=input, input=input,
specific_deployment=specific_deployment, specific_deployment=specific_deployment,
) ) # type: ignore
if healthy_deployments is None: if isinstance(healthy_deployments, dict):
return model return healthy_deployments
# filter out the deployments currently cooling down # filter out the deployments currently cooling down
deployments_to_remove = [] deployments_to_remove = []
@ -3131,7 +3155,7 @@ class Router:
): ):
deployment = await self.lowesttpm_logger_v2.async_get_available_deployments( deployment = await self.lowesttpm_logger_v2.async_get_available_deployments(
model_group=model, model_group=model,
healthy_deployments=healthy_deployments, healthy_deployments=healthy_deployments, # type: ignore
messages=messages, messages=messages,
input=input, input=input,
) )
@ -3141,7 +3165,7 @@ class Router:
): ):
deployment = await self.lowestcost_logger.async_get_available_deployments( deployment = await self.lowestcost_logger.async_get_available_deployments(
model_group=model, model_group=model,
healthy_deployments=healthy_deployments, healthy_deployments=healthy_deployments, # type: ignore
messages=messages, messages=messages,
input=input, input=input,
) )
@ -3219,8 +3243,8 @@ class Router:
specific_deployment=specific_deployment, specific_deployment=specific_deployment,
) )
if healthy_deployments is None: if isinstance(healthy_deployments, dict):
return model return healthy_deployments
# filter out the deployments currently cooling down # filter out the deployments currently cooling down
deployments_to_remove = [] deployments_to_remove = []
@ -3244,7 +3268,7 @@ class Router:
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None: if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
deployment = self.leastbusy_logger.get_available_deployments( deployment = self.leastbusy_logger.get_available_deployments(
model_group=model, healthy_deployments=healthy_deployments model_group=model, healthy_deployments=healthy_deployments # type: ignore
) )
elif self.routing_strategy == "simple-shuffle": elif self.routing_strategy == "simple-shuffle":
# if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
@ -3292,7 +3316,7 @@ class Router:
): ):
deployment = self.lowestlatency_logger.get_available_deployments( deployment = self.lowestlatency_logger.get_available_deployments(
model_group=model, model_group=model,
healthy_deployments=healthy_deployments, healthy_deployments=healthy_deployments, # type: ignore
request_kwargs=request_kwargs, request_kwargs=request_kwargs,
) )
elif ( elif (
@ -3301,7 +3325,7 @@ class Router:
): ):
deployment = self.lowesttpm_logger.get_available_deployments( deployment = self.lowesttpm_logger.get_available_deployments(
model_group=model, model_group=model,
healthy_deployments=healthy_deployments, healthy_deployments=healthy_deployments, # type: ignore
messages=messages, messages=messages,
input=input, input=input,
) )
@ -3311,7 +3335,7 @@ class Router:
): ):
deployment = self.lowesttpm_logger_v2.get_available_deployments( deployment = self.lowesttpm_logger_v2.get_available_deployments(
model_group=model, model_group=model,
healthy_deployments=healthy_deployments, healthy_deployments=healthy_deployments, # type: ignore
messages=messages, messages=messages,
input=input, input=input,
) )

View file

@ -113,6 +113,49 @@ async def get_response():
], ],
) )
return response return response
except litellm.UnprocessableEntityError as e:
pass
except Exception as e:
pytest.fail(f"An error occurred - {str(e)}")
@pytest.mark.asyncio
async def test_get_router_response():
model = "claude-3-sonnet@20240229"
vertex_ai_project = "adroit-crow-413218"
vertex_ai_location = "asia-southeast1"
json_obj = get_vertex_ai_creds_json()
vertex_credentials = json.dumps(json_obj)
prompt = '\ndef count_nums(arr):\n """\n Write a function count_nums which takes an array of integers and returns\n the number of elements which has a sum of digits > 0.\n If a number is negative, then its first signed digit will be negative:\n e.g. -123 has signed digits -1, 2, and 3.\n >>> count_nums([]) == 0\n >>> count_nums([-1, 11, -11]) == 1\n >>> count_nums([1, 1, 2]) == 3\n """\n'
try:
router = litellm.Router(
model_list=[
{
"model_name": "sonnet",
"litellm_params": {
"model": "vertex_ai/claude-3-sonnet@20240229",
"vertex_ai_project": vertex_ai_project,
"vertex_ai_location": vertex_ai_location,
"vertex_credentials": vertex_credentials,
},
}
]
)
response = await router.acompletion(
model="sonnet",
messages=[
{
"role": "system",
"content": "Complete the given code with no more explanation. Remember that there is a 4-space indent before the first line of your generated code.",
},
{"role": "user", "content": prompt},
],
)
print(f"\n\nResponse: {response}\n\n")
except litellm.UnprocessableEntityError as e: except litellm.UnprocessableEntityError as e:
pass pass
except Exception as e: except Exception as e:
@ -547,47 +590,37 @@ def test_gemini_pro_vision_base64():
pytest.fail(f"An exception occurred - {str(e)}") pytest.fail(f"An exception occurred - {str(e)}")
@pytest.mark.asyncio
def test_gemini_pro_function_calling(): def test_gemini_pro_function_calling():
try: try:
load_vertex_ai_credentials() load_vertex_ai_credentials()
tools = [ response = litellm.completion(
{ model="vertex_ai/gemini-pro",
"type": "function", messages=[
"function": { {
"name": "get_current_weather", "role": "user",
"description": "Get the current weather in a given location", "content": "Call the submit_cities function with San Francisco and New York",
"parameters": { }
"type": "object", ],
"properties": { tools=[
"location": { {
"type": "string", "type": "function",
"description": "The city and state, e.g. San Francisco, CA", "function": {
}, "name": "submit_cities",
"unit": { "description": "Submits a list of cities",
"type": "string", "parameters": {
"enum": ["celsius", "fahrenheit"], "type": "object",
"properties": {
"cities": {"type": "array", "items": {"type": "string"}}
}, },
"required": ["cities"],
}, },
"required": ["location"],
}, },
}, }
} ],
]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today in fahrenheit?",
}
]
completion = litellm.completion(
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
) )
print(f"completion: {completion}")
# assert completion.choices[0].message.content is None ## GEMINI PRO is very chatty. print(f"response: {response}")
if hasattr(completion.choices[0].message, "tool_calls") and isinstance(
completion.choices[0].message.tool_calls, list
):
assert len(completion.choices[0].message.tool_calls) == 1
except litellm.APIError as e: except litellm.APIError as e:
pass pass
except litellm.RateLimitError as e: except litellm.RateLimitError as e:
@ -596,7 +629,7 @@ def test_gemini_pro_function_calling():
if "429 Quota exceeded" in str(e): if "429 Quota exceeded" in str(e):
pass pass
else: else:
return pytest.fail("An unexpected exception occurred - {}".format(str(e)))
# gemini_pro_function_calling() # gemini_pro_function_calling()

View file

@ -206,6 +206,35 @@ def test_completion_bedrock_claude_sts_client_auth():
# test_completion_bedrock_claude_sts_client_auth() # test_completion_bedrock_claude_sts_client_auth()
@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN_V2') is None, reason="CIRCLE_OIDC_TOKEN_V2 is not set")
def test_completion_bedrock_claude_sts_oidc_auth():
print("\ncalling bedrock claude with oidc auth")
import os
aws_web_identity_token = "oidc/circleci_v2/"
aws_region_name = os.environ["AWS_REGION_NAME"]
aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
try:
litellm.set_verbose = True
response = completion(
model="bedrock/anthropic.claude-instant-v1",
messages=messages,
max_tokens=10,
temperature=0.1,
aws_region_name=aws_region_name,
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name="my-test-session",
)
# Add any assertions here to check the response
print(response)
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_bedrock_extra_headers(): def test_bedrock_extra_headers():
try: try:

View file

@ -13,6 +13,7 @@ import litellm
from litellm import embedding, completion, completion_cost, Timeout from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError from litellm import RateLimitError
from litellm.llms.prompt_templates.factory import anthropic_messages_pt from litellm.llms.prompt_templates.factory import anthropic_messages_pt
from unittest.mock import patch, MagicMock
# litellm.num_retries=3 # litellm.num_retries=3
litellm.cache = None litellm.cache = None
@ -96,7 +97,6 @@ async def test_completion_predibase(sync_mode):
response = completion( response = completion(
model="predibase/llama-3-8b-instruct", model="predibase/llama-3-8b-instruct",
tenant_id="c4768f95", tenant_id="c4768f95",
api_base="https://serving.app.predibase.com",
api_key=os.getenv("PREDIBASE_API_KEY"), api_key=os.getenv("PREDIBASE_API_KEY"),
messages=[{"role": "user", "content": "What is the meaning of life?"}], messages=[{"role": "user", "content": "What is the meaning of life?"}],
) )
@ -1138,7 +1138,7 @@ def test_get_hf_task_for_model():
model = "roneneldan/TinyStories-3M" model = "roneneldan/TinyStories-3M"
model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model) model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}") print(f"model:{model}, model type: {model_type}")
assert model_type == None assert model_type == "text-generation"
# test_get_hf_task_for_model() # test_get_hf_task_for_model()
@ -1146,15 +1146,92 @@ def test_get_hf_task_for_model():
# ################### Hugging Face TGI models ######################## # ################### Hugging Face TGI models ########################
# # TGI model # # TGI model
# # this is a TGI model https://huggingface.co/glaiveai/glaive-coder-7b # # this is a TGI model https://huggingface.co/glaiveai/glaive-coder-7b
def hf_test_completion_tgi(): def tgi_mock_post(url, data=None, json=None, headers=None):
# litellm.set_verbose=True mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = [
{
"generated_text": "<|assistant|>\nI'm",
"details": {
"finish_reason": "length",
"generated_tokens": 10,
"seed": None,
"prefill": [],
"tokens": [
{
"id": 28789,
"text": "<",
"logprob": -0.025222778,
"special": False,
},
{
"id": 28766,
"text": "|",
"logprob": -0.000003695488,
"special": False,
},
{
"id": 489,
"text": "ass",
"logprob": -0.0000019073486,
"special": False,
},
{
"id": 11143,
"text": "istant",
"logprob": -0.000002026558,
"special": False,
},
{
"id": 28766,
"text": "|",
"logprob": -0.0000015497208,
"special": False,
},
{
"id": 28767,
"text": ">",
"logprob": -0.0000011920929,
"special": False,
},
{
"id": 13,
"text": "\n",
"logprob": -0.00009703636,
"special": False,
},
{"id": 28737, "text": "I", "logprob": -0.1953125, "special": False},
{
"id": 28742,
"text": "'",
"logprob": -0.88183594,
"special": False,
},
{
"id": 28719,
"text": "m",
"logprob": -0.00032639503,
"special": False,
},
],
},
}
]
return mock_response
def test_hf_test_completion_tgi():
litellm.set_verbose = True
try: try:
response = completion( with patch("requests.post", side_effect=tgi_mock_post):
model="huggingface/HuggingFaceH4/zephyr-7b-beta", response = completion(
messages=[{"content": "Hello, how are you?", "role": "user"}], model="huggingface/HuggingFaceH4/zephyr-7b-beta",
) messages=[{"content": "Hello, how are you?", "role": "user"}],
# Add any assertions here to check the response max_tokens=10,
print(response) )
# Add any assertions here to check the response
print(response)
except litellm.ServiceUnavailableError as e: except litellm.ServiceUnavailableError as e:
pass pass
except Exception as e: except Exception as e:
@ -1192,6 +1269,40 @@ def hf_test_completion_tgi():
# except Exception as e: # except Exception as e:
# pytest.fail(f"Error occurred: {e}") # pytest.fail(f"Error occurred: {e}")
# hf_test_completion_none_task() # hf_test_completion_none_task()
def mock_post(url, data=None, json=None, headers=None):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = [
[
{"label": "LABEL_0", "score": 0.9990691542625427},
{"label": "LABEL_1", "score": 0.0009308889275416732},
]
]
return mock_response
def test_hf_classifier_task():
try:
with patch("requests.post", side_effect=mock_post):
litellm.set_verbose = True
user_message = "I like you. I love you"
messages = [{"content": user_message, "role": "user"}]
response = completion(
model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
messages=messages,
)
print(f"response: {response}")
assert isinstance(response, litellm.ModelResponse)
assert isinstance(response.choices[0], litellm.Choices)
assert response.choices[0].message.content is not None
assert isinstance(response.choices[0].message.content, str)
except Exception as e:
pytest.fail(f"Error occurred: {str(e)}")
########################### End of Hugging Face Tests ############################################## ########################### End of Hugging Face Tests ##############################################
# def test_completion_hf_api(): # def test_completion_hf_api():
# # failing on circle ci commenting out # # failing on circle ci commenting out

View file

@ -437,8 +437,9 @@ async def test_cost_tracking_with_caching():
max_tokens=40, max_tokens=40,
temperature=0.2, temperature=0.2,
caching=True, caching=True,
mock_response="Hey, i'm doing well!",
) )
await asyncio.sleep(1) # success callback is async await asyncio.sleep(3) # success callback is async
response_cost = customHandler_optional_params.response_cost response_cost = customHandler_optional_params.response_cost
assert response_cost > 0 assert response_cost > 0
response2 = await litellm.acompletion( response2 = await litellm.acompletion(

View file

@ -516,6 +516,23 @@ def test_voyage_embeddings():
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_triton_embeddings():
try:
litellm.set_verbose = True
response = await litellm.aembedding(
model="triton/my-triton-model",
api_base="https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings",
input=["good morning from litellm"],
)
print(f"response: {response}")
# stubbed endpoint is setup to return this
assert response.data[0]["embedding"] == [0.1, 0.2, 0.3]
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_voyage_embeddings() # test_voyage_embeddings()
# def test_xinference_embeddings(): # def test_xinference_embeddings():
# try: # try:

View file

@ -3,7 +3,27 @@ from litellm import get_optional_params
litellm.add_function_to_prompt = True litellm.add_function_to_prompt = True
optional_params = get_optional_params( optional_params = get_optional_params(
tools= [{'type': 'function', 'function': {'description': 'Get the current weather in a given location', 'name': 'get_current_weather', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state, e.g. San Francisco, CA'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}], model="",
tool_choice= 'auto', tools=[
{
"type": "function",
"function": {
"description": "Get the current weather in a given location",
"name": "get_current_weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
],
tool_choice="auto",
) )
assert optional_params is not None assert optional_params is not None

View file

@ -11,7 +11,6 @@ litellm.failure_callback = ["lunary"]
litellm.success_callback = ["lunary"] litellm.success_callback = ["lunary"]
litellm.set_verbose = True litellm.set_verbose = True
def test_lunary_logging(): def test_lunary_logging():
try: try:
response = completion( response = completion(
@ -59,9 +58,46 @@ def test_lunary_logging_with_metadata():
except Exception as e: except Exception as e:
print(e) print(e)
#test_lunary_logging_with_metadata()
# test_lunary_logging_with_metadata() def test_lunary_with_tools():
import litellm
messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
response = litellm.completion(
model="gpt-3.5-turbo-1106",
messages=messages,
tools=tools,
tool_choice="auto", # auto is default, but we'll be explicit
)
response_message = response.choices[0].message
print("\nLLM Response:\n", response.choices[0].message)
#test_lunary_with_tools()
def test_lunary_logging_with_streaming_and_metadata(): def test_lunary_logging_with_streaming_and_metadata():
try: try:

View file

@ -86,6 +86,7 @@ def test_azure_optional_params_embeddings():
def test_azure_gpt_optional_params_gpt_vision(): def test_azure_gpt_optional_params_gpt_vision():
# for OpenAI, Azure all extra params need to get passed as extra_body to OpenAI python. We assert we actually set extra_body here # for OpenAI, Azure all extra params need to get passed as extra_body to OpenAI python. We assert we actually set extra_body here
optional_params = litellm.utils.get_optional_params( optional_params = litellm.utils.get_optional_params(
model="",
user="John", user="John",
custom_llm_provider="azure", custom_llm_provider="azure",
max_tokens=10, max_tokens=10,
@ -125,6 +126,7 @@ def test_azure_gpt_optional_params_gpt_vision():
def test_azure_gpt_optional_params_gpt_vision_with_extra_body(): def test_azure_gpt_optional_params_gpt_vision_with_extra_body():
# if user passes extra_body, we should not over write it, we should pass it along to OpenAI python # if user passes extra_body, we should not over write it, we should pass it along to OpenAI python
optional_params = litellm.utils.get_optional_params( optional_params = litellm.utils.get_optional_params(
model="",
user="John", user="John",
custom_llm_provider="azure", custom_llm_provider="azure",
max_tokens=10, max_tokens=10,
@ -167,6 +169,7 @@ def test_azure_gpt_optional_params_gpt_vision_with_extra_body():
def test_openai_extra_headers(): def test_openai_extra_headers():
optional_params = litellm.utils.get_optional_params( optional_params = litellm.utils.get_optional_params(
model="",
user="John", user="John",
custom_llm_provider="openai", custom_llm_provider="openai",
max_tokens=10, max_tokens=10,

View file

@ -754,6 +754,9 @@ async def test_async_fallbacks_max_retries_per_request():
def test_ausage_based_routing_fallbacks(): def test_ausage_based_routing_fallbacks():
try: try:
import litellm
litellm.set_verbose = False
# [Prod Test] # [Prod Test]
# IT tests Usage Based Routing with fallbacks # IT tests Usage Based Routing with fallbacks
# The Request should fail azure/gpt-4-fast. Then fallback -> "azure/gpt-4-basic" -> "openai-gpt-4" # The Request should fail azure/gpt-4-fast. Then fallback -> "azure/gpt-4-basic" -> "openai-gpt-4"
@ -766,10 +769,10 @@ def test_ausage_based_routing_fallbacks():
load_dotenv() load_dotenv()
# Constants for TPM and RPM allocation # Constants for TPM and RPM allocation
AZURE_FAST_RPM = 0 AZURE_FAST_RPM = 1
AZURE_BASIC_RPM = 0 AZURE_BASIC_RPM = 1
OPENAI_RPM = 0 OPENAI_RPM = 0
ANTHROPIC_RPM = 2 ANTHROPIC_RPM = 10
def get_azure_params(deployment_name: str): def get_azure_params(deployment_name: str):
params = { params = {
@ -832,9 +835,9 @@ def test_ausage_based_routing_fallbacks():
fallbacks=fallbacks_list, fallbacks=fallbacks_list,
set_verbose=True, set_verbose=True,
debug_level="DEBUG", debug_level="DEBUG",
routing_strategy="usage-based-routing", routing_strategy="usage-based-routing-v2",
redis_host=os.environ["REDIS_HOST"], redis_host=os.environ["REDIS_HOST"],
redis_port=os.environ["REDIS_PORT"], redis_port=int(os.environ["REDIS_PORT"]),
num_retries=0, num_retries=0,
) )
@ -853,8 +856,8 @@ def test_ausage_based_routing_fallbacks():
# the token count of this message is > AZURE_FAST_TPM, > AZURE_BASIC_TPM # the token count of this message is > AZURE_FAST_TPM, > AZURE_BASIC_TPM
assert response._hidden_params["model_id"] == "1" assert response._hidden_params["model_id"] == "1"
# now make 100 mock requests to OpenAI - expect it to fallback to anthropic-claude-instant-1.2 for i in range(10):
for i in range(3): # now make 10 mock requests to OpenAI - expect it to fallback to anthropic-claude-instant-1.2
response = router.completion( response = router.completion(
model="azure/gpt-4-fast", model="azure/gpt-4-fast",
messages=messages, messages=messages,
@ -863,8 +866,7 @@ def test_ausage_based_routing_fallbacks():
) )
print("response: ", response) print("response: ", response)
print("response._hidden_params: ", response._hidden_params) print("response._hidden_params: ", response._hidden_params)
if i == 2: if i == 9:
# by the 19th call we should have hit TPM LIMIT for OpenAI, it should fallback to anthropic-claude-instant-1.2
assert response._hidden_params["model_id"] == "4" assert response._hidden_params["model_id"] == "4"
except Exception as e: except Exception as e:

View file

@ -23,3 +23,36 @@ def test_aws_secret_manager():
print(f"secret_val: {secret_val}") print(f"secret_val: {secret_val}")
assert secret_val == "sk-1234" assert secret_val == "sk-1234"
def redact_oidc_signature(secret_val):
# remove the last part of `.` and replace it with "SIGNATURE_REMOVED"
return ".".join(secret_val.split(".")[:-1] + ["SIGNATURE_REMOVED"])
@pytest.mark.skipif(os.environ.get('K_SERVICE') is None, reason="Cannot run without being in GCP Cloud Run")
def test_oidc_google():
secret_val = get_secret("oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")
print(f"secret_val: {redact_oidc_signature(secret_val)}")
@pytest.mark.skipif(os.environ.get('ACTIONS_ID_TOKEN_REQUEST_TOKEN') is None, reason="Cannot run without being in GitHub Actions")
def test_oidc_github():
secret_val = get_secret("oidc/github/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")
print(f"secret_val: {redact_oidc_signature(secret_val)}")
@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN') is None, reason="Cannot run without being in a CircleCI Runner")
def test_oidc_circleci():
secret_val = get_secret("oidc/circleci/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")
print(f"secret_val: {redact_oidc_signature(secret_val)}")
@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN_V2') is None, reason="Cannot run without being in a CircleCI Runner")
def test_oidc_circleci_v2():
secret_val = get_secret("oidc/circleci_v2/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")
print(f"secret_val: {redact_oidc_signature(secret_val)}")

View file

@ -33,6 +33,9 @@ from dataclasses import (
) )
import litellm._service_logger # for storing API inputs, outputs, and metadata import litellm._service_logger # for storing API inputs, outputs, and metadata
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.caching import DualCache
oidc_cache = DualCache()
try: try:
# this works in python 3.8 # this works in python 3.8
@ -1079,6 +1082,7 @@ class Logging:
litellm_call_id, litellm_call_id,
function_id, function_id,
dynamic_success_callbacks=None, dynamic_success_callbacks=None,
dynamic_failure_callbacks=None,
dynamic_async_success_callbacks=None, dynamic_async_success_callbacks=None,
langfuse_public_key=None, langfuse_public_key=None,
langfuse_secret=None, langfuse_secret=None,
@ -1113,7 +1117,7 @@ class Logging:
self.sync_streaming_chunks = [] # for generating complete stream response self.sync_streaming_chunks = [] # for generating complete stream response
self.model_call_details = {} self.model_call_details = {}
self.dynamic_input_callbacks = [] # [TODO] callbacks set for just that call self.dynamic_input_callbacks = [] # [TODO] callbacks set for just that call
self.dynamic_failure_callbacks = [] # [TODO] callbacks set for just that call self.dynamic_failure_callbacks = dynamic_failure_callbacks
self.dynamic_success_callbacks = ( self.dynamic_success_callbacks = (
dynamic_success_callbacks # callbacks set for just that call dynamic_success_callbacks # callbacks set for just that call
) )
@ -2334,11 +2338,26 @@ class Logging:
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
) )
callbacks = [] # init this to empty in case it's not created
if self.dynamic_failure_callbacks is not None and isinstance(
self.dynamic_failure_callbacks, list
):
callbacks = self.dynamic_failure_callbacks
## keep the internal functions ##
for callback in litellm.failure_callback:
if (
isinstance(callback, CustomLogger)
and "_PROXY_" in callback.__class__.__name__
):
callbacks.append(callback)
else:
callbacks = litellm.failure_callback
result = None # result sent to all loggers, init this to None incase it's not created result = None # result sent to all loggers, init this to None incase it's not created
self.redact_message_input_output_from_logging(result=result) self.redact_message_input_output_from_logging(result=result)
for callback in litellm.failure_callback: for callback in callbacks:
try: try:
if callback == "lite_debugger": if callback == "lite_debugger":
print_verbose("reaches lite_debugger for logging!") print_verbose("reaches lite_debugger for logging!")
@ -2427,7 +2446,7 @@ class Logging:
) )
elif callback == "langfuse": elif callback == "langfuse":
global langFuseLogger global langFuseLogger
verbose_logger.debug("reaches langfuse for logging!") verbose_logger.debug("reaches langfuse for logging failure")
kwargs = {} kwargs = {}
for k, v in self.model_call_details.items(): for k, v in self.model_call_details.items():
if ( if (
@ -2436,8 +2455,16 @@ class Logging:
kwargs[k] = v kwargs[k] = v
# this only logs streaming once, complete_streaming_response exists i.e when stream ends # this only logs streaming once, complete_streaming_response exists i.e when stream ends
if langFuseLogger is None or ( if langFuseLogger is None or (
self.langfuse_public_key != langFuseLogger.public_key (
and self.langfuse_secret != langFuseLogger.secret_key self.langfuse_public_key is not None
and self.langfuse_public_key
!= langFuseLogger.public_key
)
and (
self.langfuse_public_key is not None
and self.langfuse_public_key
!= langFuseLogger.public_key
)
): ):
langFuseLogger = LangFuseLogger( langFuseLogger = LangFuseLogger(
langfuse_public_key=self.langfuse_public_key, langfuse_public_key=self.langfuse_public_key,
@ -2713,6 +2740,7 @@ def function_setup(
### DYNAMIC CALLBACKS ### ### DYNAMIC CALLBACKS ###
dynamic_success_callbacks = None dynamic_success_callbacks = None
dynamic_async_success_callbacks = None dynamic_async_success_callbacks = None
dynamic_failure_callbacks = None
if kwargs.get("success_callback", None) is not None and isinstance( if kwargs.get("success_callback", None) is not None and isinstance(
kwargs["success_callback"], list kwargs["success_callback"], list
): ):
@ -2734,6 +2762,10 @@ def function_setup(
for index in reversed(removed_async_items): for index in reversed(removed_async_items):
kwargs["success_callback"].pop(index) kwargs["success_callback"].pop(index)
dynamic_success_callbacks = kwargs.pop("success_callback") dynamic_success_callbacks = kwargs.pop("success_callback")
if kwargs.get("failure_callback", None) is not None and isinstance(
kwargs["failure_callback"], list
):
dynamic_failure_callbacks = kwargs.pop("failure_callback")
if add_breadcrumb: if add_breadcrumb:
try: try:
@ -2816,9 +2848,11 @@ def function_setup(
call_type=call_type, call_type=call_type,
start_time=start_time, start_time=start_time,
dynamic_success_callbacks=dynamic_success_callbacks, dynamic_success_callbacks=dynamic_success_callbacks,
dynamic_failure_callbacks=dynamic_failure_callbacks,
dynamic_async_success_callbacks=dynamic_async_success_callbacks, dynamic_async_success_callbacks=dynamic_async_success_callbacks,
langfuse_public_key=kwargs.pop("langfuse_public_key", None), langfuse_public_key=kwargs.pop("langfuse_public_key", None),
langfuse_secret=kwargs.pop("langfuse_secret", None), langfuse_secret=kwargs.pop("langfuse_secret", None)
or kwargs.pop("langfuse_secret_key", None),
) )
## check if metadata is passed in ## check if metadata is passed in
litellm_params = {"api_base": ""} litellm_params = {"api_base": ""}
@ -4783,6 +4817,12 @@ def get_optional_params_embeddings(
status_code=500, status_code=500,
message=f"Setting dimensions is not supported for OpenAI `text-embedding-3` and later models. To drop it from the call, set `litellm.drop_params = True`.", message=f"Setting dimensions is not supported for OpenAI `text-embedding-3` and later models. To drop it from the call, set `litellm.drop_params = True`.",
) )
if custom_llm_provider == "triton":
keys = list(non_default_params.keys())
for k in keys:
non_default_params.pop(k, None)
final_params = {**non_default_params, **kwargs}
return final_params
if custom_llm_provider == "vertex_ai": if custom_llm_provider == "vertex_ai":
if len(non_default_params.keys()) > 0: if len(non_default_params.keys()) > 0:
if litellm.drop_params is True: # drop the unsupported non-default values if litellm.drop_params is True: # drop the unsupported non-default values
@ -4840,6 +4880,7 @@ def get_optional_params_embeddings(
def get_optional_params( def get_optional_params(
# use the openai defaults # use the openai defaults
# https://platform.openai.com/docs/api-reference/chat/create # https://platform.openai.com/docs/api-reference/chat/create
model: str,
functions=None, functions=None,
function_call=None, function_call=None,
temperature=None, temperature=None,
@ -4853,7 +4894,6 @@ def get_optional_params(
frequency_penalty=None, frequency_penalty=None,
logit_bias=None, logit_bias=None,
user=None, user=None,
model=None,
custom_llm_provider="", custom_llm_provider="",
response_format=None, response_format=None,
seed=None, seed=None,
@ -4882,7 +4922,7 @@ def get_optional_params(
passed_params[k] = v passed_params[k] = v
optional_params = {} optional_params: Dict = {}
common_auth_dict = litellm.common_cloud_provider_auth_params common_auth_dict = litellm.common_cloud_provider_auth_params
if custom_llm_provider in common_auth_dict["providers"]: if custom_llm_provider in common_auth_dict["providers"]:
@ -5156,41 +5196,9 @@ def get_optional_params(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
) )
_check_valid_arg(supported_params=supported_params) _check_valid_arg(supported_params=supported_params)
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None optional_params = litellm.HuggingfaceConfig().map_openai_params(
if temperature is not None: non_default_params=non_default_params, optional_params=optional_params
if temperature == 0.0 or temperature == 0: )
# hugging face exception raised when temp==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
temperature = 0.01
optional_params["temperature"] = temperature
if top_p is not None:
optional_params["top_p"] = top_p
if n is not None:
optional_params["best_of"] = n
optional_params["do_sample"] = (
True # Need to sample if you want best of for hf inference endpoints
)
if stream is not None:
optional_params["stream"] = stream
if stop is not None:
optional_params["stop"] = stop
if max_tokens is not None:
# HF TGI raises the following exception when max_new_tokens==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
if max_tokens == 0:
max_tokens = 1
optional_params["max_new_tokens"] = max_tokens
if n is not None:
optional_params["best_of"] = n
if presence_penalty is not None:
optional_params["repetition_penalty"] = presence_penalty
if "echo" in passed_params:
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
optional_params["decoder_input_details"] = special_params["echo"]
passed_params.pop(
"echo", None
) # since we handle translating echo, we should not send it to TGI request
elif custom_llm_provider == "together_ai": elif custom_llm_provider == "together_ai":
## check if unsupported param passed in ## check if unsupported param passed in
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
@ -5769,9 +5777,7 @@ def get_optional_params(
extra_body # openai client supports `extra_body` param extra_body # openai client supports `extra_body` param
) )
else: # assume passing in params for openai/azure openai else: # assume passing in params for openai/azure openai
print_verbose(
f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE - model={model}, custom_llm_provider={custom_llm_provider}"
)
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
model=model, custom_llm_provider="openai" model=model, custom_llm_provider="openai"
) )
@ -6152,7 +6158,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
"seed", "seed",
] ]
elif custom_llm_provider == "huggingface": elif custom_llm_provider == "huggingface":
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"] return litellm.HuggingfaceConfig().get_supported_openai_params()
elif custom_llm_provider == "together_ai": elif custom_llm_provider == "together_ai":
return [ return [
"stream", "stream",
@ -9408,6 +9414,72 @@ def get_secret(
if secret_name.startswith("os.environ/"): if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "") secret_name = secret_name.replace("os.environ/", "")
# Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
if secret_name.startswith("oidc/"):
secret_name_split = secret_name.replace("oidc/", "")
oidc_provider, oidc_aud = secret_name_split.split("/", 1)
# TODO: Add caching for HTTP requests
match oidc_provider:
case "google":
oidc_token = oidc_cache.get_cache(key=secret_name)
if oidc_token is not None:
return oidc_token
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
# https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
response = client.get(
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
params={"audience": oidc_aud},
headers={"Metadata-Flavor": "Google"},
)
if response.status_code == 200:
oidc_token = response.text
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
return oidc_token
else:
raise ValueError("Google OIDC provider failed")
case "circleci":
# https://circleci.com/docs/openid-connect-tokens/
env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
if env_secret is None:
raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
return env_secret
case "circleci_v2":
# https://circleci.com/docs/openid-connect-tokens/
env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
if env_secret is None:
raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
return env_secret
case "github":
# https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
if actions_id_token_request_url is None or actions_id_token_request_token is None:
raise ValueError("ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment")
oidc_token = oidc_cache.get_cache(key=secret_name)
if oidc_token is not None:
return oidc_token
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = client.get(
actions_id_token_request_url,
params={"audience": oidc_aud},
headers={
"Authorization": f"Bearer {actions_id_token_request_token}",
"Accept": "application/json; api-version=2.0",
},
)
if response.status_code == 200:
oidc_token = response.json()["value"]
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
return oidc_token
else:
raise ValueError("Github OIDC provider failed")
case _:
raise ValueError("Unsupported OIDC provider")
try: try:
if litellm.secret_manager_client is not None: if litellm.secret_manager_client is not None:
try: try:
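
The OIDC branch added to `get_secret` above recognizes secret names of the form `oidc/<provider>/<audience>` and exchanges them for a provider-issued identity token (cached for Google and GitHub). A tiny illustration of how such a name is split; the audience URL is just an example:

```python
# Illustrative parsing only - mirrors the split done inside get_secret().
secret_name = "oidc/github/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke"
oidc_provider, oidc_aud = secret_name.replace("oidc/", "").split("/", 1)
print(oidc_provider)  # -> "github"
print(oidc_aud)       # -> audience passed to the provider's token endpoint
```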

View file

@ -1571,6 +1571,135 @@
"litellm_provider": "replicate", "litellm_provider": "replicate",
"mode": "chat" "mode": "chat"
}, },
"openrouter/microsoft/wizardlm-2-8x22b:nitro": {
"max_tokens": 65536,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000001,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/google/gemini-pro-1.5": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.0000025,
"output_cost_per_token": 0.0000075,
"input_cost_per_image": 0.00265,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"openrouter/mistralai/mixtral-8x22b-instruct": {
"max_tokens": 65536,
"input_cost_per_token": 0.00000065,
"output_cost_per_token": 0.00000065,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/cohere/command-r-plus": {
"max_tokens": 128000,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/databricks/dbrx-instruct": {
"max_tokens": 32768,
"input_cost_per_token": 0.0000006,
"output_cost_per_token": 0.0000006,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/anthropic/claude-3-haiku": {
"max_tokens": 200000,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000125,
"input_cost_per_image": 0.0004,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"openrouter/anthropic/claude-3-sonnet": {
"max_tokens": 200000,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"input_cost_per_image": 0.0048,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"openrouter/mistralai/mistral-large": {
"max_tokens": 32000,
"input_cost_per_token": 0.000008,
"output_cost_per_token": 0.000024,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/cognitivecomputations/dolphin-mixtral-8x7b": {
"max_tokens": 32769,
"input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000005,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/google/gemini-pro-vision": {
"max_tokens": 45875,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000375,
"input_cost_per_image": 0.0025,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"openrouter/fireworks/firellava-13b": {
"max_tokens": 4096,
"input_cost_per_token": 0.0000002,
"output_cost_per_token": 0.0000002,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/meta-llama/llama-3-8b-instruct:free": {
"max_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/meta-llama/llama-3-8b-instruct:extended": {
"max_tokens": 16384,
"input_cost_per_token": 0.000000225,
"output_cost_per_token": 0.00000225,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/meta-llama/llama-3-70b-instruct:nitro": {
"max_tokens": 8192,
"input_cost_per_token": 0.0000009,
"output_cost_per_token": 0.0000009,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/meta-llama/llama-3-70b-instruct": {
"max_tokens": 8192,
"input_cost_per_token": 0.00000059,
"output_cost_per_token": 0.00000079,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/openai/gpt-4-vision-preview": {
"max_tokens": 130000,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003,
"input_cost_per_image": 0.01445,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
"openrouter/openai/gpt-3.5-turbo": { "openrouter/openai/gpt-3.5-turbo": {
"max_tokens": 4095, "max_tokens": 4095,
"input_cost_per_token": 0.0000015, "input_cost_per_token": 0.0000015,
@ -1621,14 +1750,14 @@
"tool_use_system_prompt_tokens": 395 "tool_use_system_prompt_tokens": 395
}, },
"openrouter/google/palm-2-chat-bison": { "openrouter/google/palm-2-chat-bison": {
"max_tokens": 8000, "max_tokens": 25804,
"input_cost_per_token": 0.0000005, "input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000005, "output_cost_per_token": 0.0000005,
"litellm_provider": "openrouter", "litellm_provider": "openrouter",
"mode": "chat" "mode": "chat"
}, },
"openrouter/google/palm-2-codechat-bison": { "openrouter/google/palm-2-codechat-bison": {
"max_tokens": 8000, "max_tokens": 20070,
"input_cost_per_token": 0.0000005, "input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000005, "output_cost_per_token": 0.0000005,
"litellm_provider": "openrouter", "litellm_provider": "openrouter",
@ -1711,13 +1840,6 @@
"litellm_provider": "openrouter", "litellm_provider": "openrouter",
"mode": "chat" "mode": "chat"
}, },
"openrouter/meta-llama/llama-3-70b-instruct": {
"max_tokens": 8192,
"input_cost_per_token": 0.0000008,
"output_cost_per_token": 0.0000008,
"litellm_provider": "openrouter",
"mode": "chat"
},
"j2-ultra": { "j2-ultra": {
"max_tokens": 8192, "max_tokens": 8192,
"max_input_tokens": 8192, "max_input_tokens": 8192,
@ -3226,4 +3348,4 @@
"mode": "embedding" "mode": "embedding"
} }
} }
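
With the new OpenRouter entries in the cost map, those models become resolvable through `litellm.get_model_info` (and are priced like any other deployment). A small offline sketch using one of the added entries:

```python
import litellm

# Cost-map lookup for one of the newly added OpenRouter entries.
info = litellm.get_model_info(model="openrouter/anthropic/claude-3-haiku")
print(info["input_cost_per_token"], info["output_cost_per_token"])
```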

View file

@ -92,10 +92,12 @@ litellm_settings:
default_team_settings: default_team_settings:
- team_id: team-1 - team_id: team-1
success_callback: ["langfuse"] success_callback: ["langfuse"]
failure_callback: ["langfuse"]
langfuse_public_key: os.environ/LANGFUSE_PROJECT1_PUBLIC # Project 1 langfuse_public_key: os.environ/LANGFUSE_PROJECT1_PUBLIC # Project 1
langfuse_secret: os.environ/LANGFUSE_PROJECT1_SECRET # Project 1 langfuse_secret: os.environ/LANGFUSE_PROJECT1_SECRET # Project 1
- team_id: team-2 - team_id: team-2
success_callback: ["langfuse"] success_callback: ["langfuse"]
failure_callback: ["langfuse"]
langfuse_public_key: os.environ/LANGFUSE_PROJECT2_PUBLIC # Project 2 langfuse_public_key: os.environ/LANGFUSE_PROJECT2_PUBLIC # Project 2
langfuse_secret: os.environ/LANGFUSE_PROJECT2_SECRET # Project 2 langfuse_secret: os.environ/LANGFUSE_PROJECT2_SECRET # Project 2

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "litellm" name = "litellm"
version = "1.37.0" version = "1.37.4"
description = "Library to easily interface with LLM API providers" description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"] authors = ["BerriAI"]
license = "MIT" license = "MIT"
@ -80,7 +80,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.commitizen] [tool.commitizen]
version = "1.37.0" version = "1.37.4"
version_files = [ version_files = [
"pyproject.toml:^version" "pyproject.toml:^version"
] ]

View file

@ -246,6 +246,33 @@ async def get_model_info_v2(session, key):
raise Exception(f"Request did not return a 200 status code: {status}") raise Exception(f"Request did not return a 200 status code: {status}")
async def get_specific_model_info_v2(session, key, model_name):
url = "http://0.0.0.0:4000/v2/model/info?debug=True&model=" + model_name
print("running /model/info check for model=", model_name)
headers = {
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
}
async with session.get(url, headers=headers) as response:
status = response.status
response_text = await response.text()
print("response from v2/model/info")
print(response_text)
print()
_json_response = await response.json()
print("JSON response from /v2/model/info?model=", model_name, _json_response)
_model_info = _json_response["data"]
assert len(_model_info) == 1, f"Expected 1 model, got {len(_model_info)}"
if status != 200:
raise Exception(f"Request did not return a 200 status code: {status}")
return _model_info[0]
async def get_model_health(session, key, model_name): async def get_model_health(session, key, model_name):
url = "http://0.0.0.0:4000/health?model=" + model_name url = "http://0.0.0.0:4000/health?model=" + model_name
headers = { headers = {
@ -285,6 +312,11 @@ async def test_add_model_run_health():
model_name = f"azure-model-health-check-{model_id}" model_name = f"azure-model-health-check-{model_id}"
print("adding model", model_name) print("adding model", model_name)
await add_model_for_health_checking(session=session, model_id=model_id) await add_model_for_health_checking(session=session, model_id=model_id)
_old_model_info = await get_specific_model_info_v2(
session=session, key=key, model_name=model_name
)
print("model info before test", _old_model_info)
await asyncio.sleep(30) await asyncio.sleep(30)
print("calling /model/info") print("calling /model/info")
await get_model_info(session=session, key=key) await get_model_info(session=session, key=key)
@ -305,5 +337,28 @@ async def test_add_model_run_health():
_healthy_endpooint["model"] == "azure/chatgpt-v-2" _healthy_endpooint["model"] == "azure/chatgpt-v-2"
) # this is the model that got added ) # this is the model that got added
# assert httpx client is unchanged
await asyncio.sleep(10)
_model_info_after_test = await get_specific_model_info_v2(
session=session, key=key, model_name=model_name
)
print("model info after test", _model_info_after_test)
old_openai_client = _old_model_info["openai_client"]
new_openai_client = _model_info_after_test["openai_client"]
print("old openai client", old_openai_client)
print("new openai client", new_openai_client)
"""
PROD TEST - This is extremely important
The OpenAI client used should be the same after 30 seconds
It is a serious bug if the openai client does not match here
"""
assert (
old_openai_client == new_openai_client
), "OpenAI client does not match for the same model after 30 seconds"
# cleanup # cleanup
await delete_model(session=session, model_id=model_id) await delete_model(session=session, model_id=model_id)

View file

@ -2,8 +2,17 @@
import React, { useState, useEffect, useRef } from "react"; import React, { useState, useEffect, useRef } from "react";
import { Button, TextInput, Grid, Col } from "@tremor/react"; import { Button, TextInput, Grid, Col } from "@tremor/react";
import { Card, Metric, Text, Title, Subtitle, Accordion, AccordionHeader, AccordionBody, } from "@tremor/react"; import {
import { CopyToClipboard } from 'react-copy-to-clipboard'; Card,
Metric,
Text,
Title,
Subtitle,
Accordion,
AccordionHeader,
AccordionBody,
} from "@tremor/react";
import { CopyToClipboard } from "react-copy-to-clipboard";
import { import {
Button as Button2, Button as Button2,
Modal, Modal,
@ -13,7 +22,11 @@ import {
Select, Select,
message, message,
} from "antd"; } from "antd";
import { keyCreateCall, slackBudgetAlertsHealthCheck, modelAvailableCall } from "./networking"; import {
keyCreateCall,
slackBudgetAlertsHealthCheck,
modelAvailableCall,
} from "./networking";
const { Option } = Select; const { Option } = Select;
@ -59,7 +72,11 @@ const CreateKey: React.FC<CreateKeyProps> = ({
} }
if (accessToken !== null) { if (accessToken !== null) {
const model_available = await modelAvailableCall(accessToken, userID, userRole); const model_available = await modelAvailableCall(
accessToken,
userID,
userRole
);
let available_model_names = model_available["data"].map( let available_model_names = model_available["data"].map(
(element: { id: string }) => element.id (element: { id: string }) => element.id
); );
@ -70,12 +87,25 @@ const CreateKey: React.FC<CreateKeyProps> = ({
console.error("Error fetching user models:", error); console.error("Error fetching user models:", error);
} }
}; };
fetchUserModels(); fetchUserModels();
}, [accessToken, userID, userRole]); }, [accessToken, userID, userRole]);
const handleCreate = async (formValues: Record<string, any>) => { const handleCreate = async (formValues: Record<string, any>) => {
try { try {
const newKeyAlias = formValues?.key_alias ?? "";
const newKeyTeamId = formValues?.team_id ?? null;
const existingKeyAliases =
data
?.filter((k) => k.team_id === newKeyTeamId)
.map((k) => k.key_alias) ?? [];
if (existingKeyAliases.includes(newKeyAlias)) {
throw new Error(
`Key alias ${newKeyAlias} already exists for team with ID ${newKeyTeamId}, please provide another key alias`
);
}
message.info("Making API Call"); message.info("Making API Call");
setIsModalVisible(true); setIsModalVisible(true);
const response = await keyCreateCall(accessToken, userID, formValues); const response = await keyCreateCall(accessToken, userID, formValues);
@ -89,12 +119,13 @@ const CreateKey: React.FC<CreateKeyProps> = ({
localStorage.removeItem("userData" + userID); localStorage.removeItem("userData" + userID);
} catch (error) { } catch (error) {
console.error("Error creating the key:", error); console.error("Error creating the key:", error);
message.error(`Error creating the key: ${error}`, 20);
} }
}; };
const handleCopy = () => { const handleCopy = () => {
message.success('API Key copied to clipboard'); message.success("API Key copied to clipboard");
}; };
useEffect(() => { useEffect(() => {
let tempModelsToPick = []; let tempModelsToPick = [];
@ -119,7 +150,6 @@ const CreateKey: React.FC<CreateKeyProps> = ({
setModelsToPick(tempModelsToPick); setModelsToPick(tempModelsToPick);
}, [team, userModels]); }, [team, userModels]);
return ( return (
<div> <div>
@ -141,140 +171,164 @@ const CreateKey: React.FC<CreateKeyProps> = ({
wrapperCol={{ span: 16 }} wrapperCol={{ span: 16 }}
labelAlign="left" labelAlign="left"
> >
<> <>
<Form.Item <Form.Item
label="Key Name" label="Key Name"
name="key_alias" name="key_alias"
rules={[{ required: true, message: 'Please input a key name' }]} rules={[{ required: true, message: "Please input a key name" }]}
help="required" help="required"
> >
<TextInput placeholder="" /> <TextInput placeholder="" />
</Form.Item> </Form.Item>
<Form.Item <Form.Item
label="Team ID" label="Team ID"
name="team_id" name="team_id"
hidden={true} hidden={true}
initialValue={team ? team["team_id"] : null} initialValue={team ? team["team_id"] : null}
valuePropName="team_id" valuePropName="team_id"
className="mt-8"
>
<Input value={team ? team["team_alias"] : ""} disabled />
</Form.Item>
<Form.Item
label="Models"
name="models"
rules={[{ required: true, message: 'Please select a model' }]}
help="required"
>
<Select
mode="multiple"
placeholder="Select models"
style={{ width: "100%" }}
onChange={(values) => {
// Check if "All Team Models" is selected
const isAllTeamModelsSelected = values.includes("all-team-models");
// If "All Team Models" is selected, deselect all other models
if (isAllTeamModelsSelected) {
const newValues = ["all-team-models"];
// You can call the form's setFieldsValue method to update the value
form.setFieldsValue({ models: newValues });
}
}}
>
<Option key="all-team-models" value="all-team-models">
All Team Models
</Option>
{
modelsToPick.map((model: string) => (
(
<Option key={model} value={model}>
{model}
</Option>
)
))
}
</Select>
</Form.Item>
<Accordion className="mt-20 mb-8" >
<AccordionHeader>
<b>Optional Settings</b>
</AccordionHeader>
<AccordionBody>
<Form.Item
className="mt-8"
label="Max Budget (USD)"
name="max_budget"
help={`Budget cannot exceed team max budget: $${team?.max_budget !== null && team?.max_budget !== undefined ? team?.max_budget : 'unlimited'}`}
rules={[
{
validator: async (_, value) => {
if (value && team && team.max_budget !== null && value > team.max_budget) {
throw new Error(`Budget cannot exceed team max budget: $${team.max_budget}`);
}
},
},
]}
>
<InputNumber step={0.01} precision={2} width={200} />
</Form.Item>
<Form.Item
className="mt-8" className="mt-8"
label="Reset Budget" >
name="budget_duration" <Input value={team ? team["team_alias"] : ""} disabled />
help={`Team Reset Budget: ${team?.budget_duration !== null && team?.budget_duration !== undefined ? team?.budget_duration : 'None'}`} </Form.Item>
>
<Select defaultValue={null} placeholder="n/a"> <Form.Item
<Select.Option value="24h">daily</Select.Option> label="Models"
<Select.Option value="30d">monthly</Select.Option> name="models"
</Select> rules={[{ required: true, message: "Please select a model" }]}
</Form.Item> help="required"
<Form.Item >
className="mt-8" <Select
label="Tokens per minute Limit (TPM)" mode="multiple"
name="tpm_limit" placeholder="Select models"
help={`TPM cannot exceed team TPM limit: ${team?.tpm_limit !== null && team?.tpm_limit !== undefined ? team?.tpm_limit : 'unlimited'}`} style={{ width: "100%" }}
rules={[ onChange={(values) => {
{ // Check if "All Team Models" is selected
validator: async (_, value) => { const isAllTeamModelsSelected =
if (value && team && team.tpm_limit !== null && value > team.tpm_limit) { values.includes("all-team-models");
throw new Error(`TPM limit cannot exceed team TPM limit: ${team.tpm_limit}`);
} // If "All Team Models" is selected, deselect all other models
}, if (isAllTeamModelsSelected) {
}, const newValues = ["all-team-models"];
]} // You can call the form's setFieldsValue method to update the value
> form.setFieldsValue({ models: newValues });
<InputNumber step={1} width={400} /> }
</Form.Item> }}
<Form.Item >
className="mt-8" <Option key="all-team-models" value="all-team-models">
label="Requests per minute Limit (RPM)" All Team Models
name="rpm_limit" </Option>
help={`RPM cannot exceed team RPM limit: ${team?.rpm_limit !== null && team?.rpm_limit !== undefined ? team?.rpm_limit : 'unlimited'}`} {modelsToPick.map((model: string) => (
rules={[ <Option key={model} value={model}>
{ {model}
validator: async (_, value) => { </Option>
if (value && team && team.rpm_limit !== null && value > team.rpm_limit) { ))}
throw new Error(`RPM limit cannot exceed team RPM limit: ${team.rpm_limit}`); </Select>
} </Form.Item>
}, <Accordion className="mt-20 mb-8">
}, <AccordionHeader>
]} <b>Optional Settings</b>
> </AccordionHeader>
<InputNumber step={1} width={400} /> <AccordionBody>
</Form.Item> <Form.Item
<Form.Item label="Expire Key (eg: 30s, 30h, 30d)" name="duration" className="mt-8"> className="mt-8"
<TextInput placeholder="" /> label="Max Budget (USD)"
</Form.Item> name="max_budget"
<Form.Item label="Metadata" name="metadata"> help={`Budget cannot exceed team max budget: $${team?.max_budget !== null && team?.max_budget !== undefined ? team?.max_budget : "unlimited"}`}
<Input.TextArea rows={4} placeholder="Enter metadata as JSON" /> rules={[
</Form.Item> {
validator: async (_, value) => {
if (
value &&
team &&
team.max_budget !== null &&
value > team.max_budget
) {
throw new Error(
`Budget cannot exceed team max budget: $${team.max_budget}`
);
}
},
},
]}
>
<InputNumber step={0.01} precision={2} width={200} />
</Form.Item>
<Form.Item
className="mt-8"
label="Reset Budget"
name="budget_duration"
help={`Team Reset Budget: ${team?.budget_duration !== null && team?.budget_duration !== undefined ? team?.budget_duration : "None"}`}
>
<Select defaultValue={null} placeholder="n/a">
<Select.Option value="24h">daily</Select.Option>
<Select.Option value="30d">monthly</Select.Option>
</Select>
</Form.Item>
<Form.Item
className="mt-8"
label="Tokens per minute Limit (TPM)"
name="tpm_limit"
help={`TPM cannot exceed team TPM limit: ${team?.tpm_limit !== null && team?.tpm_limit !== undefined ? team?.tpm_limit : "unlimited"}`}
rules={[
{
validator: async (_, value) => {
if (
value &&
team &&
team.tpm_limit !== null &&
value > team.tpm_limit
) {
throw new Error(
`TPM limit cannot exceed team TPM limit: ${team.tpm_limit}`
);
}
},
},
]}
>
<InputNumber step={1} width={400} />
</Form.Item>
<Form.Item
className="mt-8"
label="Requests per minute Limit (RPM)"
name="rpm_limit"
help={`RPM cannot exceed team RPM limit: ${team?.rpm_limit !== null && team?.rpm_limit !== undefined ? team?.rpm_limit : "unlimited"}`}
rules={[
{
validator: async (_, value) => {
if (
value &&
team &&
team.rpm_limit !== null &&
value > team.rpm_limit
) {
throw new Error(
`RPM limit cannot exceed team RPM limit: ${team.rpm_limit}`
);
}
},
},
]}
>
<InputNumber step={1} width={400} />
</Form.Item>
<Form.Item
label="Expire Key (eg: 30s, 30h, 30d)"
name="duration"
className="mt-8"
>
<TextInput placeholder="" />
</Form.Item>
<Form.Item label="Metadata" name="metadata">
<Input.TextArea
rows={4}
placeholder="Enter metadata as JSON"
/>
</Form.Item>
</AccordionBody>
</Accordion>
</>
</AccordionBody>
</Accordion>
</>
<div style={{ textAlign: "right", marginTop: "10px" }}> <div style={{ textAlign: "right", marginTop: "10px" }}>
<Button2 htmlType="submit">Create Key</Button2> <Button2 htmlType="submit">Create Key</Button2>
</div> </div>
@ -288,36 +342,45 @@ const CreateKey: React.FC<CreateKeyProps> = ({
footer={null} footer={null}
> >
<Grid numItems={1} className="gap-2 w-full"> <Grid numItems={1} className="gap-2 w-full">
<Title>Save your Key</Title>
<Title>Save your Key</Title> <Col numColSpan={1}>
<Col numColSpan={1}> <p>
<p> Please save this secret key somewhere safe and accessible. For
Please save this secret key somewhere safe and accessible. For security reasons, <b>you will not be able to view it again</b>{" "}
security reasons, <b>you will not be able to view it again</b>{" "} through your LiteLLM account. If you lose this secret key, you
through your LiteLLM account. If you lose this secret key, you will need to generate a new one.
will need to generate a new one. </p>
</p> </Col>
</Col> <Col numColSpan={1}>
<Col numColSpan={1}> {apiKey != null ? (
{apiKey != null ? ( <div>
<div>
<Text className="mt-3">API Key:</Text> <Text className="mt-3">API Key:</Text>
<div style={{ background: '#f8f8f8', padding: '10px', borderRadius: '5px', marginBottom: '10px' }}> <div
<pre style={{ wordWrap: 'break-word', whiteSpace: 'normal' }}>{apiKey}</pre> style={{
</div> background: "#f8f8f8",
padding: "10px",
<CopyToClipboard text={apiKey} onCopy={handleCopy}> borderRadius: "5px",
marginBottom: "10px",
}}
>
<pre
style={{ wordWrap: "break-word", whiteSpace: "normal" }}
>
{apiKey}
</pre>
</div>
<CopyToClipboard text={apiKey} onCopy={handleCopy}>
<Button className="mt-3">Copy API Key</Button> <Button className="mt-3">Copy API Key</Button>
</CopyToClipboard> </CopyToClipboard>
{/* <Button className="mt-3" onClick={sendSlackAlert}> {/* <Button className="mt-3" onClick={sendSlackAlert}>
Test Key Test Key
</Button> */} </Button> */}
</div> </div>
) : ( ) : (
<Text>Key being created, this might take 30s</Text> <Text>Key being created, this might take 30s</Text>
)} )}
</Col> </Col>
</Grid> </Grid>
</Modal> </Modal>
)} )}

View file

@ -2,7 +2,13 @@ import React, { useState, useEffect } from "react";
import Link from "next/link"; import Link from "next/link";
import { Typography } from "antd"; import { Typography } from "antd";
import { teamDeleteCall, teamUpdateCall, teamInfoCall } from "./networking"; import { teamDeleteCall, teamUpdateCall, teamInfoCall } from "./networking";
import { InformationCircleIcon, PencilAltIcon, PencilIcon, StatusOnlineIcon, TrashIcon } from "@heroicons/react/outline"; import {
InformationCircleIcon,
PencilAltIcon,
PencilIcon,
StatusOnlineIcon,
TrashIcon,
} from "@heroicons/react/outline";
import { import {
Button as Button2, Button as Button2,
Modal, Modal,
@ -46,8 +52,12 @@ interface EditTeamModalProps {
onSubmit: (data: FormData) => void; // Assuming FormData is the type of data to be submitted onSubmit: (data: FormData) => void; // Assuming FormData is the type of data to be submitted
} }
import {
import { teamCreateCall, teamMemberAddCall, Member, modelAvailableCall } from "./networking"; teamCreateCall,
teamMemberAddCall,
Member,
modelAvailableCall,
} from "./networking";
const Team: React.FC<TeamProps> = ({ const Team: React.FC<TeamProps> = ({
teams, teams,
@ -63,7 +73,6 @@ const Team: React.FC<TeamProps> = ({
const [value, setValue] = useState(""); const [value, setValue] = useState("");
const [editModalVisible, setEditModalVisible] = useState(false); const [editModalVisible, setEditModalVisible] = useState(false);
const [selectedTeam, setSelectedTeam] = useState<null | any>( const [selectedTeam, setSelectedTeam] = useState<null | any>(
teams ? teams[0] : null teams ? teams[0] : null
); );
@ -76,127 +85,125 @@ const Team: React.FC<TeamProps> = ({
// store team info as {"team_id": team_info_object} // store team info as {"team_id": team_info_object}
const [perTeamInfo, setPerTeamInfo] = useState<Record<string, any>>({}); const [perTeamInfo, setPerTeamInfo] = useState<Record<string, any>>({});
const EditTeamModal: React.FC<EditTeamModalProps> = ({
visible,
onCancel,
team,
onSubmit,
}) => {
const [form] = Form.useForm();
const EditTeamModal: React.FC<EditTeamModalProps> = ({ visible, onCancel, team, onSubmit }) => { const handleOk = () => {
const [form] = Form.useForm(); form
.validateFields()
.then((values) => {
const updatedValues = { ...values, team_id: team.team_id };
onSubmit(updatedValues);
form.resetFields();
})
.catch((error) => {
console.error("Validation failed:", error);
});
};
const handleOk = () => { return (
form
.validateFields()
.then((values) => {
const updatedValues = {...values, team_id: team.team_id};
onSubmit(updatedValues);
form.resetFields();
})
.catch((error) => {
console.error("Validation failed:", error);
});
};
return (
<Modal <Modal
title="Edit Team" title="Edit Team"
visible={visible} visible={visible}
width={800} width={800}
footer={null} footer={null}
onOk={handleOk} onOk={handleOk}
onCancel={onCancel} onCancel={onCancel}
>
<Form
form={form}
onFinish={handleEditSubmit}
initialValues={team} // Pass initial values here
labelCol={{ span: 8 }}
wrapperCol={{ span: 16 }}
labelAlign="left"
> >
<> <Form
<Form.Item form={form}
label="Team Name" onFinish={handleEditSubmit}
name="team_alias" initialValues={team} // Pass initial values here
rules={[{ required: true, message: 'Please input a team name' }]} labelCol={{ span: 8 }}
> wrapperCol={{ span: 16 }}
<Input /> labelAlign="left"
</Form.Item> >
<Form.Item label="Models" name="models"> <>
<Select2 <Form.Item
mode="multiple" label="Team Name"
placeholder="Select models" name="team_alias"
style={{ width: "100%" }} rules={[{ required: true, message: "Please input a team name" }]}
> >
<Select2.Option key="all-proxy-models" value="all-proxy-models"> <Input />
{"All Proxy Models"} </Form.Item>
</Select2.Option> <Form.Item label="Models" name="models">
{userModels && userModels.map((model) => ( <Select2
<Select2.Option key={model} value={model}> mode="multiple"
{model} placeholder="Select models"
</Select2.Option> style={{ width: "100%" }}
))} >
<Select2.Option key="all-proxy-models" value="all-proxy-models">
</Select2> {"All Proxy Models"}
</Form.Item> </Select2.Option>
<Form.Item label="Max Budget (USD)" name="max_budget"> {userModels &&
<InputNumber step={0.01} precision={2} width={200} /> userModels.map((model) => (
</Form.Item> <Select2.Option key={model} value={model}>
<Form.Item {model}
label="Tokens per minute Limit (TPM)" </Select2.Option>
name="tpm_limit" ))}
> </Select2>
<InputNumber step={1} width={400} /> </Form.Item>
</Form.Item> <Form.Item label="Max Budget (USD)" name="max_budget">
<Form.Item <InputNumber step={0.01} precision={2} width={200} />
label="Requests per minute Limit (RPM)" </Form.Item>
name="rpm_limit" <Form.Item label="Tokens per minute Limit (TPM)" name="tpm_limit">
> <InputNumber step={1} width={400} />
<InputNumber step={1} width={400} /> </Form.Item>
</Form.Item> <Form.Item label="Requests per minute Limit (RPM)" name="rpm_limit">
<Form.Item <InputNumber step={1} width={400} />
label="Requests per minute Limit (RPM)" </Form.Item>
name="team_id" <Form.Item
hidden={true} label="Requests per minute Limit (RPM)"
></Form.Item> name="team_id"
</> hidden={true}
<div style={{ textAlign: "right", marginTop: "10px" }}> ></Form.Item>
<Button2 htmlType="submit">Edit Team</Button2> </>
</div> <div style={{ textAlign: "right", marginTop: "10px" }}>
</Form> <Button2 htmlType="submit">Edit Team</Button2>
</Modal> </div>
); </Form>
}; </Modal>
const handleEditClick = (team: any) => {
setSelectedTeam(team);
setEditModalVisible(true);
};
const handleEditCancel = () => {
setEditModalVisible(false);
setSelectedTeam(null);
};
const handleEditSubmit = async (formValues: Record<string, any>) => {
// Call API to update team with teamId and values
const teamId = formValues.team_id; // get team_id
console.log("handleEditSubmit:", formValues);
if (accessToken == null) {
return;
}
let newTeamValues = await teamUpdateCall(accessToken, formValues);
// Update the teams state with the updated team data
if (teams) {
const updatedTeams = teams.map((team) =>
team.team_id === teamId ? newTeamValues.data : team
); );
setTeams(updatedTeams); };
}
message.success("Team updated successfully");
setEditModalVisible(false); const handleEditClick = (team: any) => {
setSelectedTeam(null); setSelectedTeam(team);
}; setEditModalVisible(true);
};
const handleEditCancel = () => {
setEditModalVisible(false);
setSelectedTeam(null);
};
const handleEditSubmit = async (formValues: Record<string, any>) => {
// Call API to update team with teamId and values
const teamId = formValues.team_id; // get team_id
console.log("handleEditSubmit:", formValues);
if (accessToken == null) {
return;
}
let newTeamValues = await teamUpdateCall(accessToken, formValues);
// Update the teams state with the updated team data
if (teams) {
const updatedTeams = teams.map((team) =>
team.team_id === teamId ? newTeamValues.data : team
);
setTeams(updatedTeams);
}
message.success("Team updated successfully");
setEditModalVisible(false);
setSelectedTeam(null);
};
const handleOk = () => { const handleOk = () => {
setIsTeamModalVisible(false); setIsTeamModalVisible(false);
@ -224,9 +231,6 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
setIsDeleteModalOpen(true); setIsDeleteModalOpen(true);
}; };
const confirmDelete = async () => { const confirmDelete = async () => {
if (teamToDelete == null || teams == null || accessToken == null) { if (teamToDelete == null || teams == null || accessToken == null) {
return; return;
@ -235,7 +239,9 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
try { try {
await teamDeleteCall(accessToken, teamToDelete); await teamDeleteCall(accessToken, teamToDelete);
// Successfully completed the deletion. Update the state to trigger a rerender. // Successfully completed the deletion. Update the state to trigger a rerender.
const filteredData = teams.filter((item) => item.team_id !== teamToDelete); const filteredData = teams.filter(
(item) => item.team_id !== teamToDelete
);
setTeams(filteredData); setTeams(filteredData);
} catch (error) { } catch (error) {
console.error("Error deleting the team:", error); console.error("Error deleting the team:", error);
@ -253,8 +259,6 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
setTeamToDelete(null); setTeamToDelete(null);
}; };
useEffect(() => { useEffect(() => {
const fetchUserModels = async () => { const fetchUserModels = async () => {
try { try {
@ -263,7 +267,11 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
} }
if (accessToken !== null) { if (accessToken !== null) {
const model_available = await modelAvailableCall(accessToken, userID, userRole); const model_available = await modelAvailableCall(
accessToken,
userID,
userRole
);
let available_model_names = model_available["data"].map( let available_model_names = model_available["data"].map(
(element: { id: string }) => element.id (element: { id: string }) => element.id
); );
@ -275,7 +283,6 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
} }
}; };
const fetchTeamInfo = async () => { const fetchTeamInfo = async () => {
try { try {
if (userID === null || userRole === null || accessToken === null) { if (userID === null || userRole === null || accessToken === null) {
@ -288,22 +295,21 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
console.log("fetching team info:"); console.log("fetching team info:");
let _team_id_to_info: Record<string, any> = {}; let _team_id_to_info: Record<string, any> = {};
for (let i = 0; i < teams?.length; i++) { for (let i = 0; i < teams?.length; i++) {
let _team_id = teams[i].team_id; let _team_id = teams[i].team_id;
const teamInfo = await teamInfoCall(accessToken, _team_id); const teamInfo = await teamInfoCall(accessToken, _team_id);
console.log("teamInfo response:", teamInfo); console.log("teamInfo response:", teamInfo);
if (teamInfo !== null) { if (teamInfo !== null) {
_team_id_to_info = {..._team_id_to_info, [_team_id]: teamInfo}; _team_id_to_info = { ..._team_id_to_info, [_team_id]: teamInfo };
} }
} }
setPerTeamInfo(_team_id_to_info); setPerTeamInfo(_team_id_to_info);
} catch (error) { } catch (error) {
console.error("Error fetching team info:", error); console.error("Error fetching team info:", error);
} }
}; };
fetchUserModels(); fetchUserModels();
fetchTeamInfo(); fetchTeamInfo();
}, [accessToken, userID, userRole, teams]); }, [accessToken, userID, userRole, teams]);
@ -311,6 +317,15 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
const handleCreate = async (formValues: Record<string, any>) => { const handleCreate = async (formValues: Record<string, any>) => {
try { try {
if (accessToken != null) { if (accessToken != null) {
const newTeamAlias = formValues?.team_alias;
const existingTeamAliases = teams?.map((t) => t.team_alias) ?? [];
if (existingTeamAliases.includes(newTeamAlias)) {
throw new Error(
`Team alias ${newTeamAlias} already exists, please pick another alias`
);
}
message.info("Creating Team"); message.info("Creating Team");
const response: any = await teamCreateCall(accessToken, formValues); const response: any = await teamCreateCall(accessToken, formValues);
if (teams !== null) { if (teams !== null) {
@ -364,7 +379,7 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
console.error("Error creating the team:", error); console.error("Error creating the team:", error);
} }
}; };
console.log(`received teams ${teams}`); console.log(`received teams ${JSON.stringify(teams)}`);
return ( return (
<div className="w-full mx-4"> <div className="w-full mx-4">
<Grid numItems={1} className="gap-2 p-8 h-[75vh] w-full mt-2"> <Grid numItems={1} className="gap-2 p-8 h-[75vh] w-full mt-2">
@ -387,55 +402,124 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
{teams && teams.length > 0 {teams && teams.length > 0
? teams.map((team: any) => ( ? teams.map((team: any) => (
<TableRow key={team.team_id}> <TableRow key={team.team_id}>
<TableCell style={{ maxWidth: "4px", whiteSpace: "pre-wrap", overflow: "hidden" }}>{team["team_alias"]}</TableCell> <TableCell
<TableCell style={{ maxWidth: "4px", whiteSpace: "pre-wrap", overflow: "hidden" }}>{team["spend"]}</TableCell> style={{
<TableCell style={{ maxWidth: "4px", whiteSpace: "pre-wrap", overflow: "hidden" }}> maxWidth: "4px",
whiteSpace: "pre-wrap",
overflow: "hidden",
}}
>
{team["team_alias"]}
</TableCell>
<TableCell
style={{
maxWidth: "4px",
whiteSpace: "pre-wrap",
overflow: "hidden",
}}
>
{team["spend"]}
</TableCell>
<TableCell
style={{
maxWidth: "4px",
whiteSpace: "pre-wrap",
overflow: "hidden",
}}
>
{team["max_budget"] ? team["max_budget"] : "No limit"} {team["max_budget"] ? team["max_budget"] : "No limit"}
</TableCell> </TableCell>
<TableCell style={{ maxWidth: "8-x", whiteSpace: "pre-wrap", overflow: "hidden" }}> <TableCell
style={{
maxWidth: "8-x",
whiteSpace: "pre-wrap",
overflow: "hidden",
}}
>
{Array.isArray(team.models) ? ( {Array.isArray(team.models) ? (
<div style={{ display: "flex", flexDirection: "column" }}> <div
style={{
display: "flex",
flexDirection: "column",
}}
>
{team.models.length === 0 ? ( {team.models.length === 0 ? (
<Badge size={"xs"} className="mb-1" color="red"> <Badge size={"xs"} className="mb-1" color="red">
<Text>All Proxy Models</Text> <Text>All Proxy Models</Text>
</Badge> </Badge>
) : ( ) : (
team.models.map((model: string, index: number) => ( team.models.map(
model === "all-proxy-models" ? ( (model: string, index: number) =>
<Badge key={index} size={"xs"} className="mb-1" color="red"> model === "all-proxy-models" ? (
<Text>All Proxy Models</Text> <Badge
</Badge> key={index}
) : ( size={"xs"}
<Badge key={index} size={"xs"} className="mb-1" color="blue"> className="mb-1"
<Text>{model.length > 30 ? `${model.slice(0, 30)}...` : model}</Text> color="red"
</Badge> >
) <Text>All Proxy Models</Text>
)) </Badge>
) : (
<Badge
key={index}
size={"xs"}
className="mb-1"
color="blue"
>
<Text>
{model.length > 30
? `${model.slice(0, 30)}...`
: model}
</Text>
</Badge>
)
)
)} )}
</div> </div>
) : null} ) : null}
</TableCell> </TableCell>
<TableCell style={{ maxWidth: "4px", whiteSpace: "pre-wrap", overflow: "hidden" }}> <TableCell
style={{
maxWidth: "4px",
whiteSpace: "pre-wrap",
overflow: "hidden",
}}
>
<Text> <Text>
TPM:{" "} TPM: {team.tpm_limit ? team.tpm_limit : "Unlimited"}{" "}
{team.tpm_limit ? team.tpm_limit : "Unlimited"}{" "}
<br></br>RPM:{" "} <br></br>RPM:{" "}
{team.rpm_limit ? team.rpm_limit : "Unlimited"} {team.rpm_limit ? team.rpm_limit : "Unlimited"}
</Text> </Text>
</TableCell> </TableCell>
<TableCell> <TableCell>
<Text>{perTeamInfo && team.team_id && perTeamInfo[team.team_id] && perTeamInfo[team.team_id].keys && perTeamInfo[team.team_id].keys.length} Keys</Text> <Text>
<Text>{perTeamInfo && team.team_id && perTeamInfo[team.team_id] && perTeamInfo[team.team_id].team_info && perTeamInfo[team.team_id].team_info.members_with_roles && perTeamInfo[team.team_id].team_info.members_with_roles.length} Members</Text> {perTeamInfo &&
team.team_id &&
perTeamInfo[team.team_id] &&
perTeamInfo[team.team_id].keys &&
perTeamInfo[team.team_id].keys.length}{" "}
Keys
</Text>
<Text>
{perTeamInfo &&
team.team_id &&
perTeamInfo[team.team_id] &&
perTeamInfo[team.team_id].team_info &&
perTeamInfo[team.team_id].team_info
.members_with_roles &&
perTeamInfo[team.team_id].team_info
.members_with_roles.length}{" "}
Members
</Text>
</TableCell> </TableCell>
<TableCell> <TableCell>
<Icon <Icon
icon={PencilAltIcon} icon={PencilAltIcon}
size="sm" size="sm"
onClick={() => handleEditClick(team)} onClick={() => handleEditClick(team)}
/> />
<Icon <Icon
onClick={() => handleDelete(team.team_id)} onClick={() => handleDelete(team.team_id)}
icon={TrashIcon} icon={TrashIcon}
size="sm" size="sm"
@ -481,7 +565,11 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
</div> </div>
</div> </div>
<div className="bg-gray-50 px-4 py-3 sm:px-6 sm:flex sm:flex-row-reverse"> <div className="bg-gray-50 px-4 py-3 sm:px-6 sm:flex sm:flex-row-reverse">
<Button onClick={confirmDelete} color="red" className="ml-2"> <Button
onClick={confirmDelete}
color="red"
className="ml-2"
>
Delete Delete
</Button> </Button>
<Button onClick={cancelDelete}>Cancel</Button> <Button onClick={cancelDelete}>Cancel</Button>
@ -515,10 +603,12 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
labelAlign="left" labelAlign="left"
> >
<> <>
<Form.Item <Form.Item
label="Team Name" label="Team Name"
name="team_alias" name="team_alias"
rules={[{ required: true, message: 'Please input a team name' }]} rules={[
{ required: true, message: "Please input a team name" },
]}
> >
<TextInput placeholder="" /> <TextInput placeholder="" />
</Form.Item> </Form.Item>
@ -528,7 +618,10 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
placeholder="Select models" placeholder="Select models"
style={{ width: "100%" }} style={{ width: "100%" }}
> >
<Select2.Option key="all-proxy-models" value="all-proxy-models"> <Select2.Option
key="all-proxy-models"
value="all-proxy-models"
>
All Proxy Models All Proxy Models
</Select2.Option> </Select2.Option>
{userModels.map((model) => ( {userModels.map((model) => (
@ -606,8 +699,8 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
{member["user_email"] {member["user_email"]
? member["user_email"] ? member["user_email"]
: member["user_id"] : member["user_id"]
? member["user_id"] ? member["user_id"]
: null} : null}
</TableCell> </TableCell>
<TableCell>{member["role"]}</TableCell> <TableCell>{member["role"]}</TableCell>
</TableRow> </TableRow>
@ -618,13 +711,13 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
</Table> </Table>
</Card> </Card>
{selectedTeam && ( {selectedTeam && (
<EditTeamModal <EditTeamModal
visible={editModalVisible} visible={editModalVisible}
onCancel={handleEditCancel} onCancel={handleEditCancel}
team={selectedTeam} team={selectedTeam}
onSubmit={handleEditSubmit} onSubmit={handleEditSubmit}
/> />
)} )}
</Col> </Col>
<Col numColSpan={1}> <Col numColSpan={1}>
<Button <Button