Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)

Merge branch 'BerriAI:main' into main

Commit: d0493248f4

37 changed files with 2206 additions and 551 deletions
@ -198,6 +198,7 @@ jobs:
       -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
       -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
       -e AWS_REGION_NAME=$AWS_REGION_NAME \
+      -e AUTO_INFER_REGION=True \
       -e OPENAI_API_KEY=$OPENAI_API_KEY \
       -e LANGFUSE_PROJECT1_PUBLIC=$LANGFUSE_PROJECT1_PUBLIC \
       -e LANGFUSE_PROJECT2_PUBLIC=$LANGFUSE_PROJECT2_PUBLIC \
@ -17,6 +17,14 @@ This covers:
 - ✅ [**JWT-Auth**](../docs/proxy/token_auth.md)

+## [COMING SOON] AWS Marketplace Support
+
+Deploy managed LiteLLM Proxy within your VPC.
+
+Includes all enterprise features.
+
+[**Get early access**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
+
 ## Frequently Asked Questions

 ### What topics does Professional support cover and what SLAs do you offer?
@ -21,6 +21,11 @@ This is done by adding the "huggingface/" prefix to `model`, example `completion
 <Tabs>
 <TabItem value="tgi" label="Text-generation-interface (TGI)">

+By default, LiteLLM will assume a huggingface call follows the TGI format.
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
 ```python
 import os
 from litellm import completion
@ -40,9 +45,58 @@ response = completion(
 print(response)
 ```

+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Add models to your config.yaml
+
+```yaml
+model_list:
+  - model_name: wizard-coder
+    litellm_params:
+      model: huggingface/WizardLM/WizardCoder-Python-34B-V1.0
+      api_key: os.environ/HUGGINGFACE_API_KEY
+      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
+```
+
+2. Start the proxy
+
+```bash
+$ litellm --config /path/to/config.yaml --debug
+```
+
+3. Test it!
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+--header 'Authorization: Bearer sk-1234' \
+--header 'Content-Type: application/json' \
+--data '{
+    "model": "wizard-coder",
+    "messages": [
+      {
+        "role": "user",
+        "content": "I like you!"
+      }
+    ],
+}'
+```
+
+</TabItem>
+</Tabs>
 </TabItem>
 <TabItem value="conv" label="Conversational-task (BlenderBot, etc.)">

+Append `conversational` to the model name
+
+e.g. `huggingface/conversational/<model-name>`
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
 ```python
 import os
 from litellm import completion
@ -54,7 +108,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
 # e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
 response = completion(
-    model="huggingface/facebook/blenderbot-400M-distill",
+    model="huggingface/conversational/facebook/blenderbot-400M-distill",
     messages=messages,
     api_base="https://my-endpoint.huggingface.cloud"
 )
@ -62,7 +116,123 @@ response = completion(
 print(response)
 ```
 </TabItem>
-<TabItem value="none" label="Non TGI/Conversational-task LLMs">
+<TabItem value="proxy" label="PROXY">
+
+1. Add models to your config.yaml
+
+```yaml
+model_list:
+  - model_name: blenderbot
+    litellm_params:
+      model: huggingface/conversational/facebook/blenderbot-400M-distill
+      api_key: os.environ/HUGGINGFACE_API_KEY
+      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
+```
+
+2. Start the proxy
+
+```bash
+$ litellm --config /path/to/config.yaml --debug
+```
+
+3. Test it!
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+--header 'Authorization: Bearer sk-1234' \
+--header 'Content-Type: application/json' \
+--data '{
+    "model": "blenderbot",
+    "messages": [
+      {
+        "role": "user",
+        "content": "I like you!"
+      }
+    ],
+}'
+```
+
+</TabItem>
+</Tabs>
+</TabItem>
+<TabItem value="classification" label="Text Classification">
+
+Append `text-classification` to the model name
+
+e.g. `huggingface/text-classification/<model-name>`
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+import os
+from litellm import completion
+
+# [OPTIONAL] set env var
+os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
+
+messages = [{ "content": "I like you, I love you!","role": "user"}]
+
+# e.g. Call 'shahrukhx01/question-vs-statement-classifier' hosted on HF Inference endpoints
+response = completion(
+    model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
+    messages=messages,
+    api_base="https://my-endpoint.endpoints.huggingface.cloud",
+)
+
+print(response)
+```
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Add models to your config.yaml
+
+```yaml
+model_list:
+  - model_name: bert-classifier
+    litellm_params:
+      model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
+      api_key: os.environ/HUGGINGFACE_API_KEY
+      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
+```
+
+2. Start the proxy
+
+```bash
+$ litellm --config /path/to/config.yaml --debug
+```
+
+3. Test it!
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+--header 'Authorization: Bearer sk-1234' \
+--header 'Content-Type: application/json' \
+--data '{
+    "model": "bert-classifier",
+    "messages": [
+      {
+        "role": "user",
+        "content": "I like you!"
+      }
+    ],
+}'
+```
+
+</TabItem>
+</Tabs>
+</TabItem>
+<TabItem value="none" label="Text Generation (NOT TGI)">
+
+Append `text-generation` to the model name
+
+e.g. `huggingface/text-generation/<model-name>`
+
 ```python
 import os
@ -75,7 +245,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
 # e.g. Call 'roneneldan/TinyStories-3M' hosted on HF Inference endpoints
 response = completion(
-    model="huggingface/roneneldan/TinyStories-3M",
+    model="huggingface/text-generation/roneneldan/TinyStories-3M",
     messages=messages,
     api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
 )
docs/my-website/docs/providers/predibase.md (new file, 247 lines)
@ -0,0 +1,247 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# 🆕 Predibase

LiteLLM supports all models on Predibase.

## Usage

<Tabs>
<TabItem value="sdk" label="SDK">

### API KEYS
```python
import os
os.environ["PREDIBASE_API_KEY"] = ""
```

### Example Call

```python
from litellm import completion
import os
## set ENV variables
os.environ["PREDIBASE_API_KEY"] = "predibase key"
os.environ["PREDIBASE_TENANT_ID"] = "predibase tenant id"

# predibase llama-3 call
response = completion(
    model="predibase/llama-3-8b-instruct",
    messages = [{ "content": "Hello, how are you?","role": "user"}]
)
```

</TabItem>
<TabItem value="proxy" label="PROXY">

1. Add models to your config.yaml

```yaml
model_list:
  - model_name: llama-3
    litellm_params:
      model: predibase/llama-3-8b-instruct
      api_key: os.environ/PREDIBASE_API_KEY
      tenant_id: os.environ/PREDIBASE_TENANT_ID
```

2. Start the proxy

```bash
$ litellm --config /path/to/config.yaml --debug
```

3. Send Request to LiteLLM Proxy Server

<Tabs>

<TabItem value="openai" label="OpenAI Python v1.0.0+">

```python
import openai
client = openai.OpenAI(
    api_key="sk-1234",             # pass litellm proxy key, if you're using virtual keys
    base_url="http://0.0.0.0:4000" # litellm-proxy-base url
)

response = client.chat.completions.create(
    model="llama-3",
    messages = [
        {
            "role": "system",
            "content": "Be a good human!"
        },
        {
            "role": "user",
            "content": "What do you know about earth?"
        }
    ]
)

print(response)
```

</TabItem>

<TabItem value="curl" label="curl">

```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "llama-3",
    "messages": [
      {
        "role": "system",
        "content": "Be a good human!"
      },
      {
        "role": "user",
        "content": "What do you know about earth?"
      }
    ],
}'
```
</TabItem>

</Tabs>

</TabItem>

</Tabs>

## Advanced Usage - Prompt Formatting

LiteLLM has prompt template mappings for all `meta-llama` llama3 instruct models. [**See Code**](https://github.com/BerriAI/litellm/blob/4f46b4c3975cd0f72b8c5acb2cb429d23580c18a/litellm/llms/prompt_templates/factory.py#L1360)

To apply a custom prompt template:

<Tabs>
<TabItem value="sdk" label="SDK">

```python
import litellm
from litellm import completion

import os
os.environ["PREDIBASE_API_KEY"] = ""

messages = [{ "content": "Hello, how are you?","role": "user"}]

# Create your own custom prompt template
litellm.register_prompt_template(
    model="togethercomputer/LLaMA-2-7B-32K",
    initial_prompt_value="You are a good assistant",  # [OPTIONAL]
    roles={
        "system": {
            "pre_message": "[INST] <<SYS>>\n",   # [OPTIONAL]
            "post_message": "\n<</SYS>>\n [/INST]\n",  # [OPTIONAL]
        },
        "user": {
            "pre_message": "[INST] ",  # [OPTIONAL]
            "post_message": " [/INST]",  # [OPTIONAL]
        },
        "assistant": {
            "pre_message": "\n",   # [OPTIONAL]
            "post_message": "\n",  # [OPTIONAL]
        },
    },
    final_prompt_value="Now answer as best you can:",  # [OPTIONAL]
)


def predibase_custom_model():
    model = "predibase/togethercomputer/LLaMA-2-7B-32K"
    response = completion(model=model, messages=messages)
    print(response['choices'][0]['message']['content'])
    return response


predibase_custom_model()
```
</TabItem>
<TabItem value="proxy" label="PROXY">

```yaml
# Model-specific parameters
model_list:
  - model_name: mistral-7b # model alias
    litellm_params: # actual params for litellm.completion()
      model: "predibase/mistralai/Mistral-7B-Instruct-v0.1"
      api_key: os.environ/PREDIBASE_API_KEY
      initial_prompt_value: "\n"
      roles: {"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}}
      final_prompt_value: "\n"
      bos_token: "<s>"
      eos_token: "</s>"
      max_tokens: 4096
```

</TabItem>

</Tabs>

## Passing additional params - max_tokens, temperature

See all litellm.completion supported params [here](https://docs.litellm.ai/docs/completion/input)

```python
# !pip install litellm
from litellm import completion
import os
## set ENV variables
os.environ["PREDIBASE_API_KEY"] = "predibase key"

# predibase llama-3 call
response = completion(
    model="predibase/llama-3-8b-instruct",
    messages = [{ "content": "Hello, how are you?","role": "user"}],
    max_tokens=20,
    temperature=0.5
)
```

**proxy**

```yaml
model_list:
  - model_name: llama-3
    litellm_params:
      model: predibase/llama-3-8b-instruct
      api_key: os.environ/PREDIBASE_API_KEY
      max_tokens: 20
      temperature: 0.5
```

## Passing Predibase-specific params - adapter_id, adapter_source

Send params [not supported by `litellm.completion()`](https://docs.litellm.ai/docs/completion/input) but supported by Predibase by passing them to `litellm.completion`.

For example, `adapter_id` and `adapter_source` are Predibase-specific params - [See List](https://github.com/BerriAI/litellm/blob/8a35354dd6dbf4c2fcefcd6e877b980fcbd68c58/litellm/llms/predibase.py#L54)

```python
# !pip install litellm
from litellm import completion
import os
## set ENV variables
os.environ["PREDIBASE_API_KEY"] = "predibase key"

# predibase llama-3 call
response = completion(
    model="predibase/llama-3-8b-instruct",
    messages = [{ "content": "Hello, how are you?","role": "user"}],
    adapter_id="my_repo/3",
    adapter_source="pbase",
)
```

**proxy**

```yaml
model_list:
  - model_name: llama-3
    litellm_params:
      model: predibase/llama-3-8b-instruct
      api_key: os.environ/PREDIBASE_API_KEY
      adapter_id: my_repo/3
      adapter_source: pbase
```
docs/my-website/docs/providers/triton-inference-server.md (new file, 95 lines)
@ -0,0 +1,95 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# Triton Inference Server

LiteLLM supports Embedding Models on Triton Inference Servers.

## Usage

<Tabs>
<TabItem value="sdk" label="SDK">

### Example Call

Use the `triton/` prefix to route to triton server
```python
import litellm

# aembedding is async - call it from inside an async function (or via asyncio.run)
response = await litellm.aembedding(
    model="triton/<your-triton-model>",
    api_base="https://your-triton-api-base/triton/embeddings", # /embeddings endpoint you want litellm to call on your server
    input=["good morning from litellm"],
)
```

</TabItem>
<TabItem value="proxy" label="PROXY">

1. Add models to your config.yaml

```yaml
model_list:
  - model_name: my-triton-model
    litellm_params:
      model: triton/<your-triton-model>
      api_base: https://your-triton-api-base/triton/embeddings
```

2. Start the proxy

```bash
$ litellm --config /path/to/config.yaml --detailed_debug
```

3. Send Request to LiteLLM Proxy Server

<Tabs>

<TabItem value="openai" label="OpenAI Python v1.0.0+">

```python
from openai import OpenAI

# set base_url to your proxy server
# set api_key to send to proxy server
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")

response = client.embeddings.create(
    input=["hello from litellm"],
    model="my-triton-model"
)

print(response)
```

</TabItem>

<TabItem value="curl" label="curl">

`--header` is optional, only required if you're using litellm proxy with Virtual Keys

```shell
curl --location 'http://0.0.0.0:4000/embeddings' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data ' {
  "model": "my-triton-model",
  "input": ["write a litellm poem"]
}'
```
</TabItem>

</Tabs>

</TabItem>

</Tabs>
@ -132,6 +132,9 @@ const sidebars = {
         "providers/cohere",
         "providers/anyscale",
         "providers/huggingface",
+        "providers/watsonx",
+        "providers/predibase",
+        "providers/triton-inference-server",
         "providers/ollama",
         "providers/perplexity",
         "providers/groq",
@ -151,7 +154,7 @@ const sidebars = {
         "providers/openrouter",
         "providers/custom_openai_proxy",
         "providers/petals",
-        "providers/watsonx",
       ],
     },
     "proxy/custom_pricing",
@ -1,5 +1,6 @@
 ### Hide pydantic namespace conflict warnings globally ###
 import warnings
+
 warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
 ### INIT VARIABLES ###
 import threading, requests, os
@ -537,6 +538,7 @@ provider_list: List = [
     "xinference",
     "fireworks_ai",
     "watsonx",
+    "triton",
     "predibase",
     "custom",  # custom apis
 ]
@ -262,7 +262,23 @@ class LangFuseLogger:
         try:
             tags = []
-            metadata = copy.deepcopy(metadata)  # Avoid modifying the original metadata
+            try:
+                metadata = copy.deepcopy(
+                    metadata
+                )  # Avoid modifying the original metadata
+            except:
+                new_metadata = {}
+                for key, value in metadata.items():
+                    if (
+                        isinstance(value, list)
+                        or isinstance(value, dict)
+                        or isinstance(value, str)
+                        or isinstance(value, int)
+                        or isinstance(value, float)
+                    ):
+                        new_metadata[key] = copy.deepcopy(value)
+                metadata = new_metadata

             supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3")
             supports_prompt = Version(langfuse.version.__version__) >= Version("2.7.3")
             supports_costs = Version(langfuse.version.__version__) >= Version("2.7.3")
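For context, a minimal sketch of the failure mode this fallback guards against; the lock value below is just an illustrative non-copyable object, not taken from the LiteLLM code:

```python
# Hypothetical illustration: deepcopy raises on values like locks or client objects,
# so the fallback above keeps only plain list/dict/str/int/float entries.
import copy
import threading

metadata = {"user": "abc", "lock": threading.Lock()}  # example values only
try:
    metadata = copy.deepcopy(metadata)
except Exception:
    metadata = {
        k: copy.deepcopy(v)
        for k, v in metadata.items()
        if isinstance(v, (list, dict, str, int, float))
    }
# metadata is now {"user": "abc"}
```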
@ -346,6 +362,7 @@ class LangFuseLogger:
                 "version": clean_metadata.pop(
                     "trace_version", clean_metadata.get("version", None)
                 ),  # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence
+                "user_id": user_id,
             }
             for key in list(
                 filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
@ -4,7 +4,6 @@ from datetime import datetime, timezone
 import traceback
 import dotenv
 import importlib
-import sys

 import packaging
@ -18,13 +17,33 @@ def parse_usage(usage):
         "prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
     }

+
+def parse_tool_calls(tool_calls):
+    if tool_calls is None:
+        return None
+
+    def clean_tool_call(tool_call):
+
+        serialized = {
+            "type": tool_call.type,
+            "id": tool_call.id,
+            "function": {
+                "name": tool_call.function.name,
+                "arguments": tool_call.function.arguments,
+            }
+        }
+
+        return serialized
+
+    return [clean_tool_call(tool_call) for tool_call in tool_calls]
+
+
 def parse_messages(input):

     if input is None:
         return None

     def clean_message(message):
-        # if is strin, return as is
+        # if is string, return as is
         if isinstance(message, str):
             return message
@ -38,9 +57,7 @@ def parse_messages(input):
         # Only add tool_calls and function_call to res if they are set
         if message.get("tool_calls"):
-            serialized["tool_calls"] = message.get("tool_calls")
+            serialized["tool_calls"] = parse_tool_calls(message.get("tool_calls"))
-        if message.get("function_call"):
-            serialized["function_call"] = message.get("function_call")

         return serialized
@ -93,8 +110,13 @@ class LunaryLogger:
             print_verbose(f"Lunary Logging - Logging request for model {model}")

             litellm_params = kwargs.get("litellm_params", {})
+            optional_params = kwargs.get("optional_params", {})
             metadata = litellm_params.get("metadata", {}) or {}
+
+            if optional_params:
+                # merge into extra
+                extra = {**extra, **optional_params}
+
             tags = litellm_params.pop("tags", None) or []

             if extra:
@ -104,7 +126,7 @@ class LunaryLogger:
             # keep only serializable types
             for param, value in extra.items():
-                if not isinstance(value, (str, int, bool, float)):
+                if not isinstance(value, (str, int, bool, float)) and param != "tools":
                     try:
                         extra[param] = str(value)
                     except:
@ -140,7 +162,7 @@ class LunaryLogger:
                 metadata=metadata,
                 runtime="litellm",
                 tags=tags,
-                extra=extra,
+                params=extra,
             )

             self.lunary_client.track_event(
@ -8,6 +8,7 @@ from litellm.utils import (
     CustomStreamWrapper,
     convert_to_model_response_object,
     TranscriptionResponse,
+    get_secret,
 )
 from typing import Callable, Optional, BinaryIO
 from litellm import OpenAIConfig
@ -16,6 +17,7 @@ import httpx # type: ignore
 from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
 from openai import AzureOpenAI, AsyncAzureOpenAI
 import uuid
+import os


 class AzureOpenAIError(Exception):
@ -126,6 +128,51 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
     return azure_client_params


+def get_azure_ad_token_from_oidc(azure_ad_token: str):
+    azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
+    azure_tenant = os.getenv("AZURE_TENANT_ID", None)
+
+    if azure_client_id is None or azure_tenant is None:
+        raise AzureOpenAIError(
+            status_code=422,
+            message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
+        )
+
+    oidc_token = get_secret(azure_ad_token)
+
+    if oidc_token is None:
+        raise AzureOpenAIError(
+            status_code=401,
+            message="OIDC token could not be retrieved from secret manager.",
+        )
+
+    req_token = httpx.post(
+        f"https://login.microsoftonline.com/{azure_tenant}/oauth2/v2.0/token",
+        data={
+            "client_id": azure_client_id,
+            "grant_type": "client_credentials",
+            "scope": "https://cognitiveservices.azure.com/.default",
+            "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
+            "client_assertion": oidc_token,
+        },
+    )
+
+    if req_token.status_code != 200:
+        raise AzureOpenAIError(
+            status_code=req_token.status_code,
+            message=req_token.text,
+        )
+
+    possible_azure_ad_token = req_token.json().get("access_token", None)
+
+    if possible_azure_ad_token is None:
+        raise AzureOpenAIError(
+            status_code=422, message="Azure AD Token not returned"
+        )
+
+    return possible_azure_ad_token
+
+
 class AzureChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
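As a usage sketch (an assumption, not part of the diff): the new helper is triggered whenever `azure_ad_token` starts with `oidc/`. The deployment name, endpoint, and secret reference below are placeholders.

```python
# Usage sketch: "oidc/..." in azure_ad_token routes through get_azure_ad_token_from_oidc().
# All names below are placeholders.
import os
import litellm

os.environ["AZURE_CLIENT_ID"] = "<app-registration-client-id>"
os.environ["AZURE_TENANT_ID"] = "<tenant-id>"

response = litellm.completion(
    model="azure/my-deployment",                      # placeholder deployment
    api_base="https://my-endpoint.openai.azure.com",  # placeholder endpoint
    api_version="2024-02-01",
    azure_ad_token="oidc/my-secret-ref",              # resolved via get_secret()
    messages=[{"role": "user", "content": "Hello"}],
)
```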
@ -137,6 +184,8 @@ class AzureChatCompletion(BaseLLM):
         if api_key is not None:
             headers["api-key"] = api_key
         elif azure_ad_token is not None:
+            if azure_ad_token.startswith("oidc/"):
+                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             headers["Authorization"] = f"Bearer {azure_ad_token}"
         return headers
@ -189,6 +238,9 @@ class AzureChatCompletion(BaseLLM):
         if api_key is not None:
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
+            if azure_ad_token.startswith("oidc/"):
+                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
+
             azure_client_params["azure_ad_token"] = azure_ad_token

         if acompletion is True:
@ -276,6 +328,8 @@ class AzureChatCompletion(BaseLLM):
         if api_key is not None:
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
+            if azure_ad_token.startswith("oidc/"):
+                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             azure_client_params["azure_ad_token"] = azure_ad_token
         if client is None:
             azure_client = AzureOpenAI(**azure_client_params)
@ -351,6 +405,8 @@ class AzureChatCompletion(BaseLLM):
         if api_key is not None:
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
+            if azure_ad_token.startswith("oidc/"):
+                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             azure_client_params["azure_ad_token"] = azure_ad_token

         # setting Azure client
@ -422,6 +478,8 @@ class AzureChatCompletion(BaseLLM):
         if api_key is not None:
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
+            if azure_ad_token.startswith("oidc/"):
+                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             azure_client_params["azure_ad_token"] = azure_ad_token
         if client is None:
             azure_client = AzureOpenAI(**azure_client_params)
@ -478,6 +536,8 @@ class AzureChatCompletion(BaseLLM):
         if api_key is not None:
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
+            if azure_ad_token.startswith("oidc/"):
+                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             azure_client_params["azure_ad_token"] = azure_ad_token
         if client is None:
             azure_client = AsyncAzureOpenAI(**azure_client_params)
@ -599,6 +659,8 @@ class AzureChatCompletion(BaseLLM):
         if api_key is not None:
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
+            if azure_ad_token.startswith("oidc/"):
+                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             azure_client_params["azure_ad_token"] = azure_ad_token

         ## LOGGING
@ -755,6 +817,8 @@ class AzureChatCompletion(BaseLLM):
         if api_key is not None:
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
+            if azure_ad_token.startswith("oidc/"):
+                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             azure_client_params["azure_ad_token"] = azure_ad_token

         if aimg_generation == True:
@ -833,6 +897,8 @@ class AzureChatCompletion(BaseLLM):
         if api_key is not None:
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
+            if azure_ad_token.startswith("oidc/"):
+                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             azure_client_params["azure_ad_token"] = azure_ad_token

         if max_retries is not None:
@ -551,6 +551,7 @@ def init_bedrock_client(
     aws_session_name: Optional[str] = None,
     aws_profile_name: Optional[str] = None,
     aws_role_name: Optional[str] = None,
+    aws_web_identity_token: Optional[str] = None,
     extra_headers: Optional[dict] = None,
     timeout: Optional[Union[float, httpx.Timeout]] = None,
 ):
@ -567,6 +568,7 @@ def init_bedrock_client(
         aws_session_name,
         aws_profile_name,
         aws_role_name,
+        aws_web_identity_token,
     ]

     # Iterate over parameters and update if needed
@ -582,6 +584,7 @@ def init_bedrock_client(
         aws_session_name,
         aws_profile_name,
         aws_role_name,
+        aws_web_identity_token,
     ) = params_to_check

     ### SET REGION NAME
@ -620,7 +623,38 @@ def init_bedrock_client(
         config = boto3.session.Config()

     ### CHECK STS ###
-    if aws_role_name is not None and aws_session_name is not None:
+    if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
+        oidc_token = get_secret(aws_web_identity_token)
+
+        if oidc_token is None:
+            raise BedrockError(
+                message="OIDC token could not be retrieved from secret manager.",
+                status_code=401,
+            )
+
+        sts_client = boto3.client(
+            "sts"
+        )
+
+        # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
+        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
+        sts_response = sts_client.assume_role_with_web_identity(
+            RoleArn=aws_role_name,
+            RoleSessionName=aws_session_name,
+            WebIdentityToken=oidc_token,
+            DurationSeconds=3600,
+        )
+
+        client = boto3.client(
+            service_name="bedrock-runtime",
+            aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
+            aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
+            aws_session_token=sts_response["Credentials"]["SessionToken"],
+            region_name=region_name,
+            endpoint_url=endpoint_url,
+            config=config,
+        )
+    elif aws_role_name is not None and aws_session_name is not None:
         # use sts if role name passed in
         sts_client = boto3.client(
             "sts",
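A hedged usage sketch for the new web-identity path (not from the diff); the role ARN, session name, and token reference are placeholders, and the keyword names mirror the parameters added to `init_bedrock_client` above:

```python
# Usage sketch: Bedrock call authenticated via AssumeRoleWithWebIdentity.
# Placeholder values only; the kwargs are popped from optional_params by completion().
import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-v2",                           # placeholder model
    messages=[{"role": "user", "content": "Hello"}],
    aws_region_name="us-east-1",
    aws_web_identity_token="oidc/my-token-ref",                    # resolved via get_secret()
    aws_role_name="arn:aws:iam::123456789012:role/bedrock-caller", # placeholder role
    aws_session_name="litellm-bedrock",                            # placeholder session
)
```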
@ -755,6 +789,7 @@ def completion(
     aws_bedrock_runtime_endpoint = optional_params.pop(
         "aws_bedrock_runtime_endpoint", None
     )
+    aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)

     # use passed in BedrockRuntime.Client if provided, otherwise create a new one
     client = optional_params.pop("aws_bedrock_client", None)
@ -769,6 +804,7 @@ def completion(
             aws_role_name=aws_role_name,
             aws_session_name=aws_session_name,
             aws_profile_name=aws_profile_name,
+            aws_web_identity_token=aws_web_identity_token,
             extra_headers=extra_headers,
             timeout=timeout,
         )
@ -1291,6 +1327,7 @@ def embedding(
     aws_bedrock_runtime_endpoint = optional_params.pop(
         "aws_bedrock_runtime_endpoint", None
     )
+    aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)

     # use passed in BedrockRuntime.Client if provided, otherwise create a new one
     client = init_bedrock_client(
@ -1298,6 +1335,7 @@ def embedding(
         aws_secret_access_key=aws_secret_access_key,
         aws_region_name=aws_region_name,
         aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
+        aws_web_identity_token=aws_web_identity_token,
         aws_role_name=aws_role_name,
         aws_session_name=aws_session_name,
     )
@ -1380,6 +1418,7 @@ def image_generation(
     aws_bedrock_runtime_endpoint = optional_params.pop(
         "aws_bedrock_runtime_endpoint", None
     )
+    aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)

     # use passed in BedrockRuntime.Client if provided, otherwise create a new one
     client = init_bedrock_client(
@ -1387,6 +1426,7 @@ def image_generation(
         aws_secret_access_key=aws_secret_access_key,
         aws_region_name=aws_region_name,
         aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
+        aws_web_identity_token=aws_web_identity_token,
         aws_role_name=aws_role_name,
         aws_session_name=aws_session_name,
         timeout=timeout,
@ -6,10 +6,12 @@ import httpx, requests
 from .base import BaseLLM
 import time
 import litellm
-from typing import Callable, Dict, List, Any
+from typing import Callable, Dict, List, Any, Literal
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
 from typing import Optional
 from .prompt_templates.factory import prompt_factory, custom_prompt
+from litellm.types.completion import ChatCompletionMessageToolCallParam
+import enum


 class HuggingfaceError(Exception):
@ -39,11 +41,29 @@ class HuggingfaceError(Exception):
         )  # Call the base class constructor with the parameters it needs


+hf_task_list = [
+    "text-generation-inference",
+    "conversational",
+    "text-classification",
+    "text-generation",
+]
+
+hf_tasks = Literal[
+    "text-generation-inference",
+    "conversational",
+    "text-classification",
+    "text-generation",
+]
+
+
 class HuggingfaceConfig:
     """
     Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
     """

+    hf_task: Optional[hf_tasks] = (
+        None  # litellm-specific param, used to know the api spec to use when calling huggingface api
+    )
     best_of: Optional[int] = None
     decoder_input_details: Optional[bool] = None
     details: Optional[bool] = True  # enables returning logprobs + best of
@ -101,6 +121,51 @@ class HuggingfaceConfig:
             and v is not None
         }

+    def get_supported_openai_params(self):
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "top_p",
+            "stop",
+            "n",
+            "echo",
+        ]
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict
+    ) -> dict:
+        for param, value in non_default_params.items():
+            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
+            if param == "temperature":
+                if value == 0.0 or value == 0:
+                    # hugging face exception raised when temp==0
+                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                    value = 0.01
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "n":
+                optional_params["best_of"] = value
+                optional_params["do_sample"] = (
+                    True  # Need to sample if you want best of for hf inference endpoints
+                )
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+            if param == "max_tokens":
+                # HF TGI raises the following exception when max_new_tokens==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+                if value == 0:
+                    value = 1
+                optional_params["max_new_tokens"] = value
+            if param == "echo":
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
+                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
+                optional_params["decoder_input_details"] = True
+        return optional_params
+
+
 def output_parser(generated_text: str):
     """
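An illustrative check of the mapping above, assuming `HuggingfaceConfig` is imported from `litellm.llms.huggingface_restapi`; the expected result follows directly from the branches shown and is not an official test:

```python
# Illustrative sketch of map_openai_params() behaviour for a few OpenAI params.
from litellm.llms.huggingface_restapi import HuggingfaceConfig

cfg = HuggingfaceConfig()
mapped = cfg.map_openai_params(
    non_default_params={"temperature": 0, "n": 2, "max_tokens": 0},
    optional_params={},
)
# mapped == {"temperature": 0.01, "best_of": 2, "do_sample": True, "max_new_tokens": 1}
```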
@ -162,16 +227,18 @@ def read_tgi_conv_models():
     return set(), set()


-def get_hf_task_for_model(model):
+def get_hf_task_for_model(model: str) -> hf_tasks:
     # read text file, cast it to set
     # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
+    if model.split("/")[0] in hf_task_list:
+        return model.split("/")[0]  # type: ignore
     tgi_models, conversational_models = read_tgi_conv_models()
     if model in tgi_models:
         return "text-generation-inference"
     elif model in conversational_models:
         return "conversational"
     elif "roneneldan/TinyStories" in model:
-        return None
+        return "text-generation"
     else:
         return "text-generation-inference"  # default to tgi
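A small illustrative sketch of the routing above, assuming the helper is imported from `litellm.llms.huggingface_restapi`; the model strings are examples only, and the last two cases depend on the bundled TGI/conversational model lists:

```python
# Illustrative routing behaviour: an explicit task prefix wins,
# otherwise the model falls back to list lookups and the TGI default.
from litellm.llms.huggingface_restapi import get_hf_task_for_model

assert get_hf_task_for_model("conversational/facebook/blenderbot-400M-distill") == "conversational"
assert get_hf_task_for_model("text-classification/shahrukhx01/question-vs-statement-classifier") == "text-classification"
assert get_hf_task_for_model("roneneldan/TinyStories-3M") == "text-generation"
assert get_hf_task_for_model("my-org/my-custom-model") == "text-generation-inference"  # default
```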
@ -202,7 +269,7 @@ class Huggingface(BaseLLM):
         self,
         completion_response,
         model_response,
-        task,
+        task: hf_tasks,
         optional_params,
         encoding,
         input_text,
@ -270,6 +337,10 @@ class Huggingface(BaseLLM):
                     )
                     choices_list.append(choice_obj)
                 model_response["choices"].extend(choices_list)
+        elif task == "text-classification":
+            model_response["choices"][0]["message"]["content"] = json.dumps(
+                completion_response
+            )
         else:
             if len(completion_response[0]["generated_text"]) > 0:
                 model_response["choices"][0]["message"]["content"] = output_parser(
@ -333,6 +404,12 @@ class Huggingface(BaseLLM):
         try:
             headers = self.validate_environment(api_key, headers)
             task = get_hf_task_for_model(model)
+            ## VALIDATE API FORMAT
+            if task is None or not isinstance(task, str) or task not in hf_task_list:
+                raise Exception(
+                    "Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
+                )
+
             print_verbose(f"{model}, {task}")
             completion_url = ""
             input_text = ""
@ -433,14 +510,15 @@ class Huggingface(BaseLLM):
                 inference_params.pop("return_full_text")
                 data = {
                     "inputs": prompt,
-                    "parameters": inference_params,
-                    "stream": (  # type: ignore
+                }
+                if task == "text-generation-inference":
+                    data["parameters"] = inference_params
+                    data["stream"] = (  # type: ignore
                         True
                         if "stream" in optional_params
                         and optional_params["stream"] == True
                         else False
-                    ),
-                }
+                    )
                 input_text = prompt
                 ## LOGGING
                 logging_obj.pre_call(
@ -531,10 +609,10 @@ class Huggingface(BaseLLM):
                 isinstance(completion_response, dict)
                 and "error" in completion_response
             ):
-                print_verbose(f"completion error: {completion_response['error']}")
+                print_verbose(f"completion error: {completion_response['error']}")  # type: ignore
                 print_verbose(f"response.status_code: {response.status_code}")
                 raise HuggingfaceError(
-                    message=completion_response["error"],
+                    message=completion_response["error"],  # type: ignore
                     status_code=response.status_code,
                 )
             return self.convert_to_model_response_object(
@ -563,7 +641,7 @@ class Huggingface(BaseLLM):
         data: dict,
         headers: dict,
         model_response: ModelResponse,
-        task: str,
+        task: hf_tasks,
         encoding: Any,
         input_text: str,
         model: str,
litellm/llms/triton.py (new file, 119 lines)
@ -0,0 +1,119 @@
import os, types
import json
from enum import Enum
import requests, copy  # type: ignore
import time
from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx  # type: ignore


class TritonError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST",
            url="https://api.anthropic.com/v1/messages",  # using anthropic api base since httpx requires a url
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class TritonChatCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    async def aembedding(
        self,
        data: dict,
        model_response: litellm.utils.EmbeddingResponse,
        api_base: str,
        logging_obj=None,
        api_key: Optional[str] = None,
    ):
        async_handler = AsyncHTTPHandler(
            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
        )

        response = await async_handler.post(url=api_base, data=json.dumps(data))

        if response.status_code != 200:
            raise TritonError(status_code=response.status_code, message=response.text)

        _text_response = response.text

        logging_obj.post_call(original_response=_text_response)

        _json_response = response.json()

        _outputs = _json_response["outputs"]
        _output_data = _outputs[0]["data"]
        _embedding_output = {
            "object": "embedding",
            "index": 0,
            "embedding": _output_data,
        }

        model_response.model = _json_response.get("model_name", "None")
        model_response.data = [_embedding_output]

        return model_response

    def embedding(
        self,
        model: str,
        input: list,
        timeout: float,
        api_base: str,
        model_response: litellm.utils.EmbeddingResponse,
        api_key: Optional[str] = None,
        logging_obj=None,
        optional_params=None,
        client=None,
        aembedding=None,
    ):
        data_for_triton = {
            "inputs": [
                {
                    "name": "input_text",
                    "shape": [1],
                    "datatype": "BYTES",
                    "data": input,
                }
            ]
        }

        ## LOGGING

        curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"

        logging_obj.pre_call(
            input="",
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": curl_string,
            },
        )

        if aembedding == True:
            response = self.aembedding(
                data=data_for_triton,
                model_response=model_response,
                logging_obj=logging_obj,
                api_base=api_base,
                api_key=api_key,
            )
            return response
        else:
            raise Exception(
                "Only async embedding supported for triton, please use litellm.aembedding() for now"
            )
@ -419,6 +419,7 @@ def completion(
     from google.protobuf.struct_pb2 import Value  # type: ignore
     from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore
     import google.auth  # type: ignore
+    import proto  # type: ignore

     ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
     print_verbose(
@ -605,9 +606,21 @@ def completion(
|
||||||
):
|
):
|
||||||
function_call = response.candidates[0].content.parts[0].function_call
|
function_call = response.candidates[0].content.parts[0].function_call
|
||||||
args_dict = {}
|
args_dict = {}
|
||||||
for k, v in function_call.args.items():
|
|
||||||
args_dict[k] = v
|
# Check if it's a RepeatedComposite instance
|
||||||
args_str = json.dumps(args_dict)
|
for key, val in function_call.args.items():
|
||||||
|
if isinstance(
|
||||||
|
val, proto.marshal.collections.repeated.RepeatedComposite
|
||||||
|
):
|
||||||
|
# If so, convert to list
|
||||||
|
args_dict[key] = [v for v in val]
|
||||||
|
else:
|
||||||
|
args_dict[key] = val
|
||||||
|
|
||||||
|
try:
|
||||||
|
args_str = json.dumps(args_dict)
|
||||||
|
except Exception as e:
|
||||||
|
raise VertexAIError(status_code=422, message=str(e))
|
||||||
message = litellm.Message(
|
message = litellm.Message(
|
||||||
content=None,
|
content=None,
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
|
@ -810,6 +823,8 @@ def completion(
|
||||||
setattr(model_response, "usage", usage)
|
setattr(model_response, "usage", usage)
|
||||||
return model_response
|
return model_response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
if isinstance(e, VertexAIError):
|
||||||
|
raise e
|
||||||
raise VertexAIError(status_code=500, message=str(e))
|
raise VertexAIError(status_code=500, message=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@@ -14,7 +14,6 @@ import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
 import httpx
 import litellm
-
 from ._logging import verbose_logger
 from litellm import (  # type: ignore
     client,
@@ -47,6 +46,7 @@ from .llms import (
     ai21,
     sagemaker,
     bedrock,
+    triton,
     huggingface_restapi,
     replicate,
     aleph_alpha,
@@ -75,6 +75,7 @@ from .llms.anthropic import AnthropicChatCompletion
 from .llms.anthropic_text import AnthropicTextCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.predibase import PredibaseChatCompletion
+from .llms.triton import TritonChatCompletion
 from .llms.prompt_templates.factory import (
     prompt_factory,
     custom_prompt,
@@ -112,6 +113,7 @@ azure_chat_completions = AzureChatCompletion()
 azure_text_completions = AzureTextCompletion()
 huggingface = Huggingface()
 predibase_chat_completions = PredibaseChatCompletion()
+triton_chat_completions = TritonChatCompletion()
 ####### COMPLETION ENDPOINTS ################
 
 
@@ -662,6 +664,7 @@ def completion(
         "region_name",
         "allowed_model_region",
     ]
 
     default_params = openai_params + litellm_params
     non_default_params = {
         k: v for k, v in kwargs.items() if k not in default_params
@@ -2621,6 +2624,7 @@ async def aembedding(*args, **kwargs):
             or custom_llm_provider == "voyage"
             or custom_llm_provider == "mistral"
             or custom_llm_provider == "custom_openai"
+            or custom_llm_provider == "triton"
             or custom_llm_provider == "anyscale"
             or custom_llm_provider == "openrouter"
             or custom_llm_provider == "deepinfra"
@@ -2954,23 +2958,43 @@ def embedding(
             optional_params=optional_params,
             model_response=EmbeddingResponse(),
         )
+    elif custom_llm_provider == "triton":
+        if api_base is None:
+            raise ValueError(
+                "api_base is required for triton. Please pass `api_base`"
+            )
+        response = triton_chat_completions.embedding(
+            model=model,
+            input=input,
+            api_base=api_base,
+            api_key=api_key,
+            logging_obj=logging,
+            timeout=timeout,
+            model_response=EmbeddingResponse(),
+            optional_params=optional_params,
+            client=client,
+            aembedding=aembedding,
+        )
     elif custom_llm_provider == "vertex_ai":
         vertex_ai_project = (
             optional_params.pop("vertex_project", None)
             or optional_params.pop("vertex_ai_project", None)
             or litellm.vertex_project
             or get_secret("VERTEXAI_PROJECT")
+            or get_secret("VERTEX_PROJECT")
         )
         vertex_ai_location = (
            optional_params.pop("vertex_location", None)
            or optional_params.pop("vertex_ai_location", None)
            or litellm.vertex_location
            or get_secret("VERTEXAI_LOCATION")
+            or get_secret("VERTEX_LOCATION")
        )
        vertex_credentials = (
            optional_params.pop("vertex_credentials", None)
            or optional_params.pop("vertex_ai_credentials", None)
            or get_secret("VERTEXAI_CREDENTIALS")
+            or get_secret("VERTEX_CREDENTIALS")
        )
 
        response = vertex_ai.embedding(
@@ -20,22 +20,20 @@ model_list:
 - litellm_params:
     model: together_ai/codellama/CodeLlama-13b-Instruct-hf
   model_name: CodeLlama-13b-Instruct
-router_settings:
-  num_retries: 0
-  enable_pre_call_checks: true
-  redis_host: os.environ/REDIS_HOST
-  redis_password: os.environ/REDIS_PASSWORD
-  redis_port: os.environ/REDIS_PORT
 
 router_settings:
-  routing_strategy: "latency-based-routing"
+  redis_host: redis
+  # redis_password: <your redis password>
+  redis_port: 6379
 
 litellm_settings:
-  success_callback: ["langfuse"]
+  set_verbose: True
+  # service_callback: ["prometheus_system"]
+  # success_callback: ["prometheus"]
+  # failure_callback: ["prometheus"]
 
 general_settings:
-  alerting: ["slack"]
-  alert_types: ["llm_exceptions", "daily_reports"]
-  alerting_args:
-    daily_report_frequency: 60 # every minute
-    report_check_interval: 5 # every 5s
+  enable_jwt_auth: True
+  disable_reset_budget: True
+  proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
+  routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle"
@@ -156,6 +156,11 @@ class JWTHandler:
         return public_key
 
     async def auth_jwt(self, token: str) -> dict:
+        # Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html
+        # "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret
+        #   the key in different ways (e.g. HS* and RS*)."
+        algorithms = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]
+
         audience = os.getenv("JWT_AUDIENCE")
         decode_options = None
         if audience is None:
@@ -189,7 +194,7 @@ class JWTHandler:
             payload = jwt.decode(
                 token,
                 public_key_rsa,  # type: ignore
-                algorithms=["RS256"],
+                algorithms=algorithms,
                 options=decode_options,
                 audience=audience,
             )
@@ -214,7 +219,7 @@ class JWTHandler:
             payload = jwt.decode(
                 token,
                 key,
-                algorithms=["RS256"],
+                algorithms=algorithms,
                 audience=audience,
                 options=decode_options
             )
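
For reference, a small PyJWT sketch of the asymmetric-only verification enforced above; the token, key, and audience values are placeholders, not part of this change.

```python
from typing import Optional

import jwt  # PyJWT

# Mirror the allow-list above so an HS*-signed token can never be
# validated against an RSA/PS public key.
ALLOWED_ALGORITHMS = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]


def verify_jwt(token: str, public_key_pem: str, audience: Optional[str] = None) -> dict:
    options = {"verify_aud": False} if audience is None else None
    return jwt.decode(
        token,
        public_key_pem,
        algorithms=ALLOWED_ALGORITHMS,
        audience=audience,
        options=options,
    )
```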
@@ -8,7 +8,10 @@ model_list:
     litellm_params:
       model: openai/*
       api_key: os.environ/OPENAI_API_KEY
+  - model_name: my-triton-model
+    litellm_params:
+      model: triton/any"
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings
 
 general_settings:
   store_model_in_db: true
@@ -17,4 +20,10 @@ general_settings:
 
 litellm_settings:
   success_callback: ["langfuse"]
-  _langfuse_default_tags: ["user_api_key_alias", "user_api_key_user_id", "user_api_key_user_email", "user_api_key_team_alias", "semantic-similarity", "proxy_base_url"]
+  failure_callback: ["langfuse"]
+  default_team_settings:
+    - team_id: 7bf09cd5-217a-40d4-8634-fc31d9b88bf4
+      success_callback: ["langfuse"]
+      failure_callback: ["langfuse"]
+      langfuse_public_key: "os.environ/LANGFUSE_DEV_PUBLIC_KEY"
+      langfuse_secret_key: "os.environ/LANGFUSE_DEV_SK_KEY"
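
A sketch of how a client could exercise the `my-triton-model` entry above through the proxy's OpenAI-compatible embeddings route; the local URL and `sk-1234` key are placeholders for however the proxy is actually deployed.

```python
import openai

# Point the OpenAI SDK at the LiteLLM proxy (placeholder address and key).
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.embeddings.create(
    model="my-triton-model",
    input=["good morning from litellm"],
)
print(response.data[0].embedding)
```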
@@ -7795,11 +7795,15 @@ async def update_model(
 )
 async def model_info_v2(
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+    model: Optional[str] = fastapi.Query(
+        None, description="Specify the model name (optional)"
+    ),
+    debug: Optional[bool] = False,
 ):
     """
     BETA ENDPOINT. Might change unexpectedly. Use `/v1/model/info` for now.
     """
-    global llm_model_list, general_settings, user_config_file_path, proxy_config
+    global llm_model_list, general_settings, user_config_file_path, proxy_config, llm_router
 
     if llm_model_list is None or not isinstance(llm_model_list, list):
         raise HTTPException(
@@ -7822,19 +7826,35 @@ async def model_info_v2(
     if len(user_api_key_dict.models) > 0:
         user_models = user_api_key_dict.models
 
+    if model is not None:
+        all_models = [m for m in all_models if m["model_name"] == model]
+
     # fill in model info based on config.yaml and litellm model_prices_and_context_window.json
-    for model in all_models:
+    for _model in all_models:
         # provided model_info in config.yaml
-        model_info = model.get("model_info", {})
+        model_info = _model.get("model_info", {})
+        if debug == True:
+            _openai_client = "None"
+            if llm_router is not None:
+                _openai_client = (
+                    llm_router._get_client(
+                        deployment=_model, kwargs={}, client_type="async"
+                    )
+                    or "None"
+                )
+            else:
+                _openai_client = "llm_router_is_None"
+            openai_client = str(_openai_client)
+            _model["openai_client"] = openai_client
 
         # read litellm model_prices_and_context_window.json to get the following:
         # input_cost_per_token, output_cost_per_token, max_tokens
-        litellm_model_info = get_litellm_model_info(model=model)
+        litellm_model_info = get_litellm_model_info(model=_model)
 
         # 2nd pass on the model, try seeing if we can find model in litellm model_cost map
         if litellm_model_info == {}:
             # use litellm_param model_name to get model_info
-            litellm_params = model.get("litellm_params", {})
+            litellm_params = _model.get("litellm_params", {})
             litellm_model = litellm_params.get("model", None)
             try:
                 litellm_model_info = litellm.get_model_info(model=litellm_model)
@@ -7843,7 +7863,7 @@ async def model_info_v2(
         # 3rd pass on the model, try seeing if we can find model but without the "/" in model cost map
         if litellm_model_info == {}:
             # use litellm_param model_name to get model_info
-            litellm_params = model.get("litellm_params", {})
+            litellm_params = _model.get("litellm_params", {})
             litellm_model = litellm_params.get("model", None)
             split_model = litellm_model.split("/")
             if len(split_model) > 0:
@@ -7855,10 +7875,10 @@ async def model_info_v2(
         for k, v in litellm_model_info.items():
             if k not in model_info:
                 model_info[k] = v
-        model["model_info"] = model_info
+        _model["model_info"] = model_info
         # don't return the api key / vertex credentials
-        model["litellm_params"].pop("api_key", None)
-        model["litellm_params"].pop("vertex_credentials", None)
+        _model["litellm_params"].pop("api_key", None)
+        _model["litellm_params"].pop("vertex_credentials", None)
 
     verbose_proxy_logger.debug("all_models: %s", all_models)
     return {"data": all_models}
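
A hedged sketch of the new `model` filter and `debug` flag on the endpoint above; it assumes the handler is mounted at `/v2/model/info` and that an admin key is available, both of which are assumptions here.

```python
import requests

# Ask for a single model group and include the per-deployment openai_client debug field.
resp = requests.get(
    "http://0.0.0.0:4000/v2/model/info",  # assumed mount point for model_info_v2
    headers={"Authorization": "Bearer sk-1234"},  # placeholder admin key
    params={"model": "my-triton-model", "debug": "true"},
)
print(resp.json()["data"])
```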
@@ -9,7 +9,7 @@
 
 import copy, httpx
 from datetime import datetime
-from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO
+from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple
 import random, threading, time, traceback, uuid
 import litellm, openai, hashlib, json
 from litellm.caching import RedisCache, InMemoryCache, DualCache
@@ -48,6 +48,7 @@ from litellm.types.router import (
     AlertingConfig,
 )
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.llms.azure import get_azure_ad_token_from_oidc
 
 
 class Router:
@@ -102,6 +103,7 @@ class Router:
             "usage-based-routing",
             "latency-based-routing",
             "cost-based-routing",
+            "usage-based-routing-v2",
         ] = "simple-shuffle",
         routing_strategy_args: dict = {},  # just for latency-based routing
         semaphore: Optional[asyncio.Semaphore] = None,
@@ -2114,6 +2116,10 @@ class Router:
                     raise ValueError(
                         f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
                     )
+                azure_ad_token = litellm_params.get("azure_ad_token")
+                if azure_ad_token is not None:
+                    if azure_ad_token.startswith("oidc/"):
+                        azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                 if api_version is None:
                     api_version = "2023-07-01-preview"
 
@@ -2125,6 +2131,7 @@ class Router:
                     cache_key = f"{model_id}_async_client"
                     _client = openai.AsyncAzureOpenAI(
                         api_key=api_key,
+                        azure_ad_token=azure_ad_token,
                         base_url=api_base,
                         api_version=api_version,
                         timeout=timeout,
@@ -2149,6 +2156,7 @@ class Router:
                     cache_key = f"{model_id}_client"
                     _client = openai.AzureOpenAI(  # type: ignore
                         api_key=api_key,
+                        azure_ad_token=azure_ad_token,
                         base_url=api_base,
                         api_version=api_version,
                         timeout=timeout,
@@ -2173,6 +2181,7 @@ class Router:
                     cache_key = f"{model_id}_stream_async_client"
                     _client = openai.AsyncAzureOpenAI(  # type: ignore
                         api_key=api_key,
+                        azure_ad_token=azure_ad_token,
                         base_url=api_base,
                         api_version=api_version,
                         timeout=stream_timeout,
@@ -2197,6 +2206,7 @@ class Router:
                     cache_key = f"{model_id}_stream_client"
                     _client = openai.AzureOpenAI(  # type: ignore
                         api_key=api_key,
+                        azure_ad_token=azure_ad_token,
                         base_url=api_base,
                         api_version=api_version,
                         timeout=stream_timeout,
@@ -2229,6 +2239,7 @@ class Router:
                         "api_key": api_key,
                         "azure_endpoint": api_base,
                         "api_version": api_version,
+                        "azure_ad_token": azure_ad_token,
                     }
                     from litellm.llms.azure import select_azure_base_url_or_endpoint
 
@@ -2557,20 +2568,27 @@ class Router:
             self.set_client(model=deployment.to_json(exclude_none=True))
 
         # set region (if azure model)
-        try:
-            if "azure" in deployment.litellm_params.model:
-                region = litellm.utils.get_model_region(
-                    litellm_params=deployment.litellm_params, mode=None
-                )
+        _auto_infer_region = os.environ.get("AUTO_INFER_REGION", False)
+        if _auto_infer_region == True or _auto_infer_region == "True":
+            print("Auto inferring region")  # noqa
+            """
+            Hiding behind a feature flag
+            When there is a large amount of LLM deployments this makes startup times blow up
+            """
+            try:
+                if "azure" in deployment.litellm_params.model:
+                    region = litellm.utils.get_model_region(
+                        litellm_params=deployment.litellm_params, mode=None
+                    )
 
-                deployment.litellm_params.region_name = region
-        except Exception as e:
-            verbose_router_logger.error(
-                "Unable to get the region for azure model - {}, {}".format(
-                    deployment.litellm_params.model, str(e)
-                )
-            )
-            pass  # [NON-BLOCKING]
+                    deployment.litellm_params.region_name = region
+            except Exception as e:
+                verbose_router_logger.error(
+                    "Unable to get the region for azure model - {}, {}".format(
+                        deployment.litellm_params.model, str(e)
+                    )
+                )
+                pass  # [NON-BLOCKING]
 
         return deployment
 
@@ -2599,7 +2617,7 @@ class Router:
             self.model_names.append(deployment.model_name)
         return deployment
 
-    def upsert_deployment(self, deployment: Deployment) -> Deployment:
+    def upsert_deployment(self, deployment: Deployment) -> Optional[Deployment]:
         """
         Add or update deployment
         Parameters:
@@ -2609,8 +2627,17 @@ class Router:
         - The added/updated deployment
         """
         # check if deployment already exists
+        _deployment_model_id = deployment.model_info.id or ""
+        _deployment_on_router: Optional[Deployment] = self.get_deployment(
+            model_id=_deployment_model_id
+        )
+        if _deployment_on_router is not None:
+            # deployment with this model_id exists on the router
+            if deployment.litellm_params == _deployment_on_router.litellm_params:
+                # No need to update
+                return None
 
-        if deployment.model_info.id in self.get_model_ids():
+            # if there is a new litellm param -> then update the deployment
             # remove the previous deployment
             removal_idx: Optional[int] = None
             for idx, model in enumerate(self.model_list):
@@ -2619,16 +2646,9 @@ class Router:
 
             if removal_idx is not None:
                 self.model_list.pop(removal_idx)
-
-        # add to model list
-        _deployment = deployment.to_json(exclude_none=True)
-        self.model_list.append(_deployment)
-
-        # initialize client
-        self._add_deployment(deployment=deployment)
-
-        # add to model names
-        self.model_names.append(deployment.model_name)
+        else:
+            # if the model_id is not in router
+            self.add_deployment(deployment=deployment)
         return deployment
 
     def delete_deployment(self, id: str) -> Optional[Deployment]:
@@ -2989,11 +3009,15 @@ class Router:
         messages: Optional[List[Dict[str, str]]] = None,
         input: Optional[Union[str, List]] = None,
         specific_deployment: Optional[bool] = False,
-    ):
+    ) -> Tuple[str, Union[list, dict]]:
         """
         Common checks for 'get_available_deployment' across sync + async call.
 
         If 'healthy_deployments' returned is None, this means the user chose a specific deployment
+
+        Returns
+        - Dict, if specific model chosen
+        - List, if multiple models chosen
         """
         # check if aliases set on litellm model alias map
         if specific_deployment == True:
@@ -3003,7 +3027,7 @@ class Router:
                 if deployment_model == model:
                     # User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
                     # return the first deployment where the `model` matches the specificed deployment name
-                    return deployment, None
+                    return deployment_model, deployment
             raise ValueError(
                 f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}"
             )
@@ -3019,7 +3043,7 @@ class Router:
                 self.default_deployment
             )  # self.default_deployment
             updated_deployment["litellm_params"]["model"] = model
-            return updated_deployment, None
+            return model, updated_deployment
 
         ## get healthy deployments
         ### get all deployments
@@ -3072,10 +3096,10 @@ class Router:
             messages=messages,
             input=input,
             specific_deployment=specific_deployment,
-        )
+        )  # type: ignore
 
-        if healthy_deployments is None:
-            return model
+        if isinstance(healthy_deployments, dict):
+            return healthy_deployments
 
         # filter out the deployments currently cooling down
         deployments_to_remove = []
@@ -3131,7 +3155,7 @@ class Router:
         ):
             deployment = await self.lowesttpm_logger_v2.async_get_available_deployments(
                 model_group=model,
-                healthy_deployments=healthy_deployments,
+                healthy_deployments=healthy_deployments,  # type: ignore
                 messages=messages,
                 input=input,
             )
@@ -3141,7 +3165,7 @@ class Router:
         ):
             deployment = await self.lowestcost_logger.async_get_available_deployments(
                 model_group=model,
-                healthy_deployments=healthy_deployments,
+                healthy_deployments=healthy_deployments,  # type: ignore
                 messages=messages,
                 input=input,
             )
@@ -3219,8 +3243,8 @@ class Router:
             specific_deployment=specific_deployment,
         )
 
-        if healthy_deployments is None:
-            return model
+        if isinstance(healthy_deployments, dict):
+            return healthy_deployments
 
         # filter out the deployments currently cooling down
         deployments_to_remove = []
@@ -3244,7 +3268,7 @@ class Router:
 
         if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
             deployment = self.leastbusy_logger.get_available_deployments(
-                model_group=model, healthy_deployments=healthy_deployments
+                model_group=model, healthy_deployments=healthy_deployments  # type: ignore
             )
         elif self.routing_strategy == "simple-shuffle":
             # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
@@ -3292,7 +3316,7 @@ class Router:
         ):
             deployment = self.lowestlatency_logger.get_available_deployments(
                 model_group=model,
-                healthy_deployments=healthy_deployments,
+                healthy_deployments=healthy_deployments,  # type: ignore
                 request_kwargs=request_kwargs,
             )
         elif (
@@ -3301,7 +3325,7 @@ class Router:
         ):
             deployment = self.lowesttpm_logger.get_available_deployments(
                 model_group=model,
-                healthy_deployments=healthy_deployments,
+                healthy_deployments=healthy_deployments,  # type: ignore
                 messages=messages,
                 input=input,
             )
@@ -3311,7 +3335,7 @@ class Router:
         ):
             deployment = self.lowesttpm_logger_v2.get_available_deployments(
                 model_group=model,
-                healthy_deployments=healthy_deployments,
+                healthy_deployments=healthy_deployments,  # type: ignore
                 messages=messages,
                 input=input,
             )
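
A sketch combining the two router additions above: an Azure deployment whose `azure_ad_token` is resolved through the `oidc/` prefix, and the `AUTO_INFER_REGION` flag that gates region inference at startup. All deployment values are placeholders.

```python
import os

from litellm import Router

# Opt in to the feature-flagged region auto-inference shown above.
os.environ["AUTO_INFER_REGION"] = "True"

router = Router(
    model_list=[
        {
            "model_name": "azure-gpt",
            "litellm_params": {
                "model": "azure/my-gpt-deployment",  # placeholder deployment name
                "api_base": "https://my-endpoint.openai.azure.com",  # placeholder
                "api_version": "2023-07-01-preview",
                # "oidc/..." values are exchanged via get_azure_ad_token_from_oidc();
                # the provider path here is illustrative only.
                "azure_ad_token": "oidc/circleci_v2/",
            },
        }
    ],
    routing_strategy="simple-shuffle",
)
```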
@@ -113,6 +113,49 @@ async def get_response():
             ],
         )
         return response
+    except litellm.UnprocessableEntityError as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"An error occurred - {str(e)}")
+
+
+@pytest.mark.asyncio
+async def test_get_router_response():
+    model = "claude-3-sonnet@20240229"
+    vertex_ai_project = "adroit-crow-413218"
+    vertex_ai_location = "asia-southeast1"
+    json_obj = get_vertex_ai_creds_json()
+    vertex_credentials = json.dumps(json_obj)
+
+    prompt = '\ndef count_nums(arr):\n    """\n    Write a function count_nums which takes an array of integers and returns\n    the number of elements which has a sum of digits > 0.\n    If a number is negative, then its first signed digit will be negative:\n    e.g. -123 has signed digits -1, 2, and 3.\n    >>> count_nums([]) == 0\n    >>> count_nums([-1, 11, -11]) == 1\n    >>> count_nums([1, 1, 2]) == 3\n    """\n'
+    try:
+        router = litellm.Router(
+            model_list=[
+                {
+                    "model_name": "sonnet",
+                    "litellm_params": {
+                        "model": "vertex_ai/claude-3-sonnet@20240229",
+                        "vertex_ai_project": vertex_ai_project,
+                        "vertex_ai_location": vertex_ai_location,
+                        "vertex_credentials": vertex_credentials,
+                    },
+                }
+            ]
+        )
+        response = await router.acompletion(
+            model="sonnet",
+            messages=[
+                {
+                    "role": "system",
+                    "content": "Complete the given code with no more explanation. Remember that there is a 4-space indent before the first line of your generated code.",
+                },
+                {"role": "user", "content": prompt},
+            ],
+        )
+
+        print(f"\n\nResponse: {response}\n\n")
+
     except litellm.UnprocessableEntityError as e:
         pass
     except Exception as e:
@@ -547,47 +590,37 @@ def test_gemini_pro_vision_base64():
         pytest.fail(f"An exception occurred - {str(e)}")
 
 
+@pytest.mark.asyncio
 def test_gemini_pro_function_calling():
     try:
         load_vertex_ai_credentials()
-        tools = [
-            {
-                "type": "function",
-                "function": {
-                    "name": "get_current_weather",
-                    "description": "Get the current weather in a given location",
-                    "parameters": {
-                        "type": "object",
-                        "properties": {
-                            "location": {
-                                "type": "string",
-                                "description": "The city and state, e.g. San Francisco, CA",
-                            },
-                            "unit": {
-                                "type": "string",
-                                "enum": ["celsius", "fahrenheit"],
-                            },
-                        },
-                        "required": ["location"],
-                    },
-                },
-            }
-        ]
-        messages = [
-            {
-                "role": "user",
-                "content": "What's the weather like in Boston today in fahrenheit?",
-            }
-        ]
-        completion = litellm.completion(
-            model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
-        )
-        print(f"completion: {completion}")
-        # assert completion.choices[0].message.content is None ## GEMINI PRO is very chatty.
-        if hasattr(completion.choices[0].message, "tool_calls") and isinstance(
-            completion.choices[0].message.tool_calls, list
-        ):
-            assert len(completion.choices[0].message.tool_calls) == 1
+        response = litellm.completion(
+            model="vertex_ai/gemini-pro",
+            messages=[
+                {
+                    "role": "user",
+                    "content": "Call the submit_cities function with San Francisco and New York",
+                }
+            ],
+            tools=[
+                {
+                    "type": "function",
+                    "function": {
+                        "name": "submit_cities",
+                        "description": "Submits a list of cities",
+                        "parameters": {
+                            "type": "object",
+                            "properties": {
+                                "cities": {"type": "array", "items": {"type": "string"}}
+                            },
+                            "required": ["cities"],
+                        },
+                    },
+                }
+            ],
+        )
+        print(f"response: {response}")
     except litellm.APIError as e:
         pass
     except litellm.RateLimitError as e:
@@ -596,7 +629,7 @@ def test_gemini_pro_function_calling():
         if "429 Quota exceeded" in str(e):
             pass
         else:
-            return
+            pytest.fail("An unexpected exception occurred - {}".format(str(e)))
 
 
 # gemini_pro_function_calling()
@@ -206,6 +206,35 @@ def test_completion_bedrock_claude_sts_client_auth():
 
 # test_completion_bedrock_claude_sts_client_auth()
 
+@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN_V2') is None, reason="CIRCLE_OIDC_TOKEN_V2 is not set")
+def test_completion_bedrock_claude_sts_oidc_auth():
+    print("\ncalling bedrock claude with oidc auth")
+    import os
+
+    aws_web_identity_token = "oidc/circleci_v2/"
+    aws_region_name = os.environ["AWS_REGION_NAME"]
+    aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
+
+    try:
+        litellm.set_verbose = True
+
+        response = completion(
+            model="bedrock/anthropic.claude-instant-v1",
+            messages=messages,
+            max_tokens=10,
+            temperature=0.1,
+            aws_region_name=aws_region_name,
+            aws_web_identity_token=aws_web_identity_token,
+            aws_role_name=aws_role_name,
+            aws_session_name="my-test-session",
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except RateLimitError:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_bedrock_extra_headers():
     try:
@@ -13,6 +13,7 @@ import litellm
 from litellm import embedding, completion, completion_cost, Timeout
 from litellm import RateLimitError
 from litellm.llms.prompt_templates.factory import anthropic_messages_pt
+from unittest.mock import patch, MagicMock
 
 # litellm.num_retries=3
 litellm.cache = None
@@ -96,7 +97,6 @@ async def test_completion_predibase(sync_mode):
             response = completion(
                 model="predibase/llama-3-8b-instruct",
                 tenant_id="c4768f95",
-                api_base="https://serving.app.predibase.com",
                 api_key=os.getenv("PREDIBASE_API_KEY"),
                 messages=[{"role": "user", "content": "What is the meaning of life?"}],
             )
@@ -1138,7 +1138,7 @@ def test_get_hf_task_for_model():
     model = "roneneldan/TinyStories-3M"
     model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
     print(f"model:{model}, model type: {model_type}")
-    assert model_type == None
+    assert model_type == "text-generation"
 
 
 # test_get_hf_task_for_model()
@@ -1146,15 +1146,92 @@ def test_get_hf_task_for_model():
 # ################### Hugging Face TGI models ########################
 # # TGI model
 # # this is a TGI model https://huggingface.co/glaiveai/glaive-coder-7b
-def hf_test_completion_tgi():
-    # litellm.set_verbose=True
+def tgi_mock_post(url, data=None, json=None, headers=None):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = [
+        {
+            "generated_text": "<|assistant|>\nI'm",
+            "details": {
+                "finish_reason": "length",
+                "generated_tokens": 10,
+                "seed": None,
+                "prefill": [],
+                "tokens": [
+                    {
+                        "id": 28789,
+                        "text": "<",
+                        "logprob": -0.025222778,
+                        "special": False,
+                    },
+                    {
+                        "id": 28766,
+                        "text": "|",
+                        "logprob": -0.000003695488,
+                        "special": False,
+                    },
+                    {
+                        "id": 489,
+                        "text": "ass",
+                        "logprob": -0.0000019073486,
+                        "special": False,
+                    },
+                    {
+                        "id": 11143,
+                        "text": "istant",
+                        "logprob": -0.000002026558,
+                        "special": False,
+                    },
+                    {
+                        "id": 28766,
+                        "text": "|",
+                        "logprob": -0.0000015497208,
+                        "special": False,
+                    },
+                    {
+                        "id": 28767,
+                        "text": ">",
+                        "logprob": -0.0000011920929,
+                        "special": False,
+                    },
+                    {
+                        "id": 13,
+                        "text": "\n",
+                        "logprob": -0.00009703636,
+                        "special": False,
+                    },
+                    {"id": 28737, "text": "I", "logprob": -0.1953125, "special": False},
+                    {
+                        "id": 28742,
+                        "text": "'",
+                        "logprob": -0.88183594,
+                        "special": False,
+                    },
+                    {
+                        "id": 28719,
+                        "text": "m",
+                        "logprob": -0.00032639503,
+                        "special": False,
+                    },
+                ],
+            },
+        }
+    ]
+    return mock_response
+
+
+def test_hf_test_completion_tgi():
+    litellm.set_verbose = True
     try:
-        response = completion(
-            model="huggingface/HuggingFaceH4/zephyr-7b-beta",
-            messages=[{"content": "Hello, how are you?", "role": "user"}],
-        )
-        # Add any assertions here to check the response
-        print(response)
+        with patch("requests.post", side_effect=tgi_mock_post):
+            response = completion(
+                model="huggingface/HuggingFaceH4/zephyr-7b-beta",
+                messages=[{"content": "Hello, how are you?", "role": "user"}],
+                max_tokens=10,
+            )
+            # Add any assertions here to check the response
+            print(response)
     except litellm.ServiceUnavailableError as e:
         pass
     except Exception as e:
@@ -1192,6 +1269,40 @@ def hf_test_completion_tgi():
 # except Exception as e:
 #     pytest.fail(f"Error occurred: {e}")
 # hf_test_completion_none_task()
+
+
+def mock_post(url, data=None, json=None, headers=None):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = [
+        [
+            {"label": "LABEL_0", "score": 0.9990691542625427},
+            {"label": "LABEL_1", "score": 0.0009308889275416732},
+        ]
+    ]
+    return mock_response
+
+
+def test_hf_classifier_task():
+    try:
+        with patch("requests.post", side_effect=mock_post):
+            litellm.set_verbose = True
+            user_message = "I like you. I love you"
+            messages = [{"content": user_message, "role": "user"}]
+            response = completion(
+                model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
+                messages=messages,
+            )
+            print(f"response: {response}")
+            assert isinstance(response, litellm.ModelResponse)
+            assert isinstance(response.choices[0], litellm.Choices)
+            assert response.choices[0].message.content is not None
+            assert isinstance(response.choices[0].message.content, str)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {str(e)}")
+
+
 ########################### End of Hugging Face Tests ##############################################
 # def test_completion_hf_api():
 #     # failing on circle ci commenting out
@@ -437,8 +437,9 @@ async def test_cost_tracking_with_caching():
         max_tokens=40,
         temperature=0.2,
         caching=True,
+        mock_response="Hey, i'm doing well!",
     )
-    await asyncio.sleep(1)  # success callback is async
+    await asyncio.sleep(3)  # success callback is async
     response_cost = customHandler_optional_params.response_cost
     assert response_cost > 0
     response2 = await litellm.acompletion(
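
A standalone sketch of the `mock_response` knob used above; it returns a canned completion without calling a provider, which is handy when exercising callbacks and cost tracking locally.

```python
import asyncio

import litellm


async def main():
    # mock_response short-circuits the provider call and returns this string.
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hi, how are you?"}],
        mock_response="Hey, i'm doing well!",
    )
    print(response.choices[0].message.content)


asyncio.run(main())
```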
@@ -516,6 +516,23 @@ def test_voyage_embeddings():
         pytest.fail(f"Error occurred: {e}")
 
 
+@pytest.mark.asyncio
+async def test_triton_embeddings():
+    try:
+        litellm.set_verbose = True
+        response = await litellm.aembedding(
+            model="triton/my-triton-model",
+            api_base="https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings",
+            input=["good morning from litellm"],
+        )
+        print(f"response: {response}")
+
+        # stubbed endpoint is setup to return this
+        assert response.data[0]["embedding"] == [0.1, 0.2, 0.3]
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 # test_voyage_embeddings()
 # def test_xinference_embeddings():
 #     try:
@@ -3,7 +3,27 @@ from litellm import get_optional_params
 
 litellm.add_function_to_prompt = True
 optional_params = get_optional_params(
-    tools= [{'type': 'function', 'function': {'description': 'Get the current weather in a given location', 'name': 'get_current_weather', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state, e.g. San Francisco, CA'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}],
-    tool_choice= 'auto',
+    model="",
+    tools=[
+        {
+            "type": "function",
+            "function": {
+                "description": "Get the current weather in a given location",
+                "name": "get_current_weather",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ],
+    tool_choice="auto",
 )
 assert optional_params is not None
@@ -11,7 +11,6 @@ litellm.failure_callback = ["lunary"]
 litellm.success_callback = ["lunary"]
 litellm.set_verbose = True
 
-
 def test_lunary_logging():
     try:
         response = completion(
@@ -59,9 +58,46 @@ def test_lunary_logging_with_metadata():
     except Exception as e:
         print(e)
 
-
-# test_lunary_logging_with_metadata()
+#test_lunary_logging_with_metadata()
+
+def test_lunary_with_tools():
+    import litellm
+
+    messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+
+    response = litellm.completion(
+        model="gpt-3.5-turbo-1106",
+        messages=messages,
+        tools=tools,
+        tool_choice="auto",  # auto is default, but we'll be explicit
+    )
+
+    response_message = response.choices[0].message
+    print("\nLLM Response:\n", response.choices[0].message)
+
+
+#test_lunary_with_tools()
+
 def test_lunary_logging_with_streaming_and_metadata():
     try:
@@ -86,6 +86,7 @@ def test_azure_optional_params_embeddings():
 def test_azure_gpt_optional_params_gpt_vision():
     # for OpenAI, Azure all extra params need to get passed as extra_body to OpenAI python. We assert we actually set extra_body here
     optional_params = litellm.utils.get_optional_params(
+        model="",
         user="John",
         custom_llm_provider="azure",
         max_tokens=10,
@@ -125,6 +126,7 @@ def test_azure_gpt_optional_params_gpt_vision():
 def test_azure_gpt_optional_params_gpt_vision_with_extra_body():
     # if user passes extra_body, we should not over write it, we should pass it along to OpenAI python
     optional_params = litellm.utils.get_optional_params(
+        model="",
         user="John",
         custom_llm_provider="azure",
         max_tokens=10,
@@ -167,6 +169,7 @@ def test_azure_gpt_optional_params_gpt_vision_with_extra_body():
 
 def test_openai_extra_headers():
     optional_params = litellm.utils.get_optional_params(
+        model="",
         user="John",
         custom_llm_provider="openai",
         max_tokens=10,
@@ -754,6 +754,9 @@ async def test_async_fallbacks_max_retries_per_request():
 
 def test_ausage_based_routing_fallbacks():
     try:
+        import litellm
+
+        litellm.set_verbose = False
         # [Prod Test]
         # IT tests Usage Based Routing with fallbacks
         # The Request should fail azure/gpt-4-fast. Then fallback -> "azure/gpt-4-basic" -> "openai-gpt-4"
@@ -766,10 +769,10 @@ def test_ausage_based_routing_fallbacks():
         load_dotenv()
 
         # Constants for TPM and RPM allocation
-        AZURE_FAST_RPM = 0
-        AZURE_BASIC_RPM = 0
+        AZURE_FAST_RPM = 1
+        AZURE_BASIC_RPM = 1
         OPENAI_RPM = 0
-        ANTHROPIC_RPM = 2
+        ANTHROPIC_RPM = 10
 
         def get_azure_params(deployment_name: str):
             params = {
@@ -832,9 +835,9 @@ def test_ausage_based_routing_fallbacks():
             fallbacks=fallbacks_list,
             set_verbose=True,
             debug_level="DEBUG",
-            routing_strategy="usage-based-routing",
+            routing_strategy="usage-based-routing-v2",
             redis_host=os.environ["REDIS_HOST"],
-            redis_port=os.environ["REDIS_PORT"],
+            redis_port=int(os.environ["REDIS_PORT"]),
             num_retries=0,
         )
 
@@ -853,8 +856,8 @@ def test_ausage_based_routing_fallbacks():
         # the token count of this message is > AZURE_FAST_TPM, > AZURE_BASIC_TPM
         assert response._hidden_params["model_id"] == "1"
 
-        # now make 100 mock requests to OpenAI - expect it to fallback to anthropic-claude-instant-1.2
-        for i in range(3):
+        for i in range(10):
+            # now make 100 mock requests to OpenAI - expect it to fallback to anthropic-claude-instant-1.2
             response = router.completion(
                 model="azure/gpt-4-fast",
                 messages=messages,
@@ -863,8 +866,7 @@ def test_ausage_based_routing_fallbacks():
             )
             print("response: ", response)
             print("response._hidden_params: ", response._hidden_params)
-            if i == 2:
-                # by the 19th call we should have hit TPM LIMIT for OpenAI, it should fallback to anthropic-claude-instant-1.2
+            if i == 9:
                 assert response._hidden_params["model_id"] == "4"
 
     except Exception as e:
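
A minimal sketch of the `usage-based-routing-v2` strategy the updated test switches to, assuming Redis connection details are available in the environment as in the test.

```python
import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "rpm": 100},
        }
    ],
    routing_strategy="usage-based-routing-v2",  # strategy added in this change
    redis_host=os.environ["REDIS_HOST"],
    redis_port=int(os.environ["REDIS_PORT"]),
    redis_password=os.environ.get("REDIS_PASSWORD"),
)

response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)
print(response)
```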
@@ -23,3 +23,36 @@ def test_aws_secret_manager():
     print(f"secret_val: {secret_val}")
 
     assert secret_val == "sk-1234"
+
+
+def redact_oidc_signature(secret_val):
+    # remove the last part of `.` and replace it with "SIGNATURE_REMOVED"
+    return secret_val.split(".")[:-1] + ["SIGNATURE_REMOVED"]
+
+
+@pytest.mark.skipif(os.environ.get('K_SERVICE') is None, reason="Cannot run without being in GCP Cloud Run")
+def test_oidc_google():
+    secret_val = get_secret("oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")
+
+    print(f"secret_val: {redact_oidc_signature(secret_val)}")
+
+
+@pytest.mark.skipif(os.environ.get('ACTIONS_ID_TOKEN_REQUEST_TOKEN') is None, reason="Cannot run without being in GitHub Actions")
+def test_oidc_github():
+    secret_val = get_secret("oidc/github/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")
+
+    print(f"secret_val: {redact_oidc_signature(secret_val)}")
+
+
+@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN') is None, reason="Cannot run without being in a CircleCI Runner")
+def test_oidc_circleci():
+    secret_val = get_secret("oidc/circleci/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")
+
+    print(f"secret_val: {redact_oidc_signature(secret_val)}")
+
+
+@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN_V2') is None, reason="Cannot run without being in a CircleCI Runner")
+def test_oidc_circleci_v2():
+    secret_val = get_secret("oidc/circleci_v2/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")
+
+    print(f"secret_val: {redact_oidc_signature(secret_val)}")
|
|
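The `redact_oidc_signature` helper above strips the signature segment of a JWT before it is printed in CI logs; as written it returns a list of segments rather than a rejoined string. A small illustrative variant that rejoins the segments, shown only to make the intent concrete:

```python
def redact_oidc_signature_str(secret_val: str) -> str:
    # A JWT is "<header>.<payload>.<signature>"; keep everything except the signature.
    parts = secret_val.split(".")
    return ".".join(parts[:-1] + ["SIGNATURE_REMOVED"])

# Example: redact_oidc_signature_str("aaa.bbb.ccc") -> "aaa.bbb.SIGNATURE_REMOVED"
```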
166 litellm/utils.py
@@ -33,6 +33,9 @@ from dataclasses import (
 )

 import litellm._service_logger  # for storing API inputs, outputs, and metadata
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from litellm.caching import DualCache
+oidc_cache = DualCache()

 try:
     # this works in python 3.8
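The new module-level `oidc_cache` is a `DualCache` that the `get_secret` changes further down use to memoize OIDC tokens with a TTL slightly shorter than the token lifetime. A minimal sketch of that get/set pattern; the key, the stub fetch function, and the TTL value are illustrative:

```python
from litellm.caching import DualCache

cache = DualCache()

def fetch_token_somehow() -> str:
    # hypothetical stand-in for the HTTP call to the OIDC provider
    return "header.payload.signature"

key = "oidc/google/https://example-audience"  # illustrative cache key
token = cache.get_cache(key=key)
if token is None:
    token = fetch_token_somehow()
    cache.set_cache(key=key, value=token, ttl=3600 - 60)  # refresh a minute before expiry
```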
@@ -1079,6 +1082,7 @@ class Logging:
         litellm_call_id,
         function_id,
         dynamic_success_callbacks=None,
+        dynamic_failure_callbacks=None,
         dynamic_async_success_callbacks=None,
         langfuse_public_key=None,
         langfuse_secret=None,
@@ -1113,7 +1117,7 @@ class Logging:
         self.sync_streaming_chunks = []  # for generating complete stream response
         self.model_call_details = {}
         self.dynamic_input_callbacks = []  # [TODO] callbacks set for just that call
-        self.dynamic_failure_callbacks = []  # [TODO] callbacks set for just that call
+        self.dynamic_failure_callbacks = dynamic_failure_callbacks
         self.dynamic_success_callbacks = (
             dynamic_success_callbacks  # callbacks set for just that call
         )
@@ -2334,11 +2338,26 @@ class Logging:
             start_time=start_time,
             end_time=end_time,
         )
+        callbacks = []  # init this to empty incase it's not created
+
+        if self.dynamic_failure_callbacks is not None and isinstance(
+            self.dynamic_failure_callbacks, list
+        ):
+            callbacks = self.dynamic_failure_callbacks
+            ## keep the internal functions ##
+            for callback in litellm.failure_callback:
+                if (
+                    isinstance(callback, CustomLogger)
+                    and "_PROXY_" in callback.__class__.__name__
+                ):
+                    callbacks.append(callback)
+        else:
+            callbacks = litellm.failure_callback
+
         result = None  # result sent to all loggers, init this to None incase it's not created

         self.redact_message_input_output_from_logging(result=result)
-        for callback in litellm.failure_callback:
+        for callback in callbacks:
             try:
                 if callback == "lite_debugger":
                     print_verbose("reaches lite_debugger for logging!")
@@ -2427,7 +2446,7 @@ class Logging:
                     )
                 elif callback == "langfuse":
                     global langFuseLogger
-                    verbose_logger.debug("reaches langfuse for logging!")
+                    verbose_logger.debug("reaches langfuse for logging failure")
                     kwargs = {}
                     for k, v in self.model_call_details.items():
                         if (
@@ -2436,8 +2455,16 @@ class Logging:
                             kwargs[k] = v
                     # this only logs streaming once, complete_streaming_response exists i.e when stream ends
                     if langFuseLogger is None or (
-                        self.langfuse_public_key != langFuseLogger.public_key
-                        and self.langfuse_secret != langFuseLogger.secret_key
+                        (
+                            self.langfuse_public_key is not None
+                            and self.langfuse_public_key
+                            != langFuseLogger.public_key
+                        )
+                        and (
+                            self.langfuse_public_key is not None
+                            and self.langfuse_public_key
+                            != langFuseLogger.public_key
+                        )
                     ):
                         langFuseLogger = LangFuseLogger(
                             langfuse_public_key=self.langfuse_public_key,
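The failure handler above now prefers per-call callbacks over the global `litellm.failure_callback` list while always keeping internal proxy loggers (classes whose name contains `_PROXY_`). A standalone sketch of that selection rule, with stand-in logger types so it runs on its own:

```python
class CustomLogger:
    # stand-in for litellm's CustomLogger base class
    pass

class _PROXY_SpendLogger(CustomLogger):
    # hypothetical internal proxy logger, named to match the "_PROXY_" check
    pass

def resolve_failure_callbacks(dynamic_failure_callbacks, global_failure_callbacks):
    if isinstance(dynamic_failure_callbacks, list):
        callbacks = list(dynamic_failure_callbacks)
        # keep internal proxy loggers even when the caller overrides the list
        for cb in global_failure_callbacks:
            if isinstance(cb, CustomLogger) and "_PROXY_" in cb.__class__.__name__:
                callbacks.append(cb)
        return callbacks
    return global_failure_callbacks

# A request-scoped ["langfuse"] override still keeps the proxy logger:
print(resolve_failure_callbacks(["langfuse"], ["sentry", _PROXY_SpendLogger()]))
```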
@@ -2713,6 +2740,7 @@ def function_setup(
     ### DYNAMIC CALLBACKS ###
     dynamic_success_callbacks = None
     dynamic_async_success_callbacks = None
+    dynamic_failure_callbacks = None
     if kwargs.get("success_callback", None) is not None and isinstance(
         kwargs["success_callback"], list
     ):
@@ -2734,6 +2762,10 @@ def function_setup(
         for index in reversed(removed_async_items):
             kwargs["success_callback"].pop(index)
         dynamic_success_callbacks = kwargs.pop("success_callback")
+    if kwargs.get("failure_callback", None) is not None and isinstance(
+        kwargs["failure_callback"], list
+    ):
+        dynamic_failure_callbacks = kwargs.pop("failure_callback")

     if add_breadcrumb:
         try:
@@ -2816,9 +2848,11 @@ def function_setup(
         call_type=call_type,
         start_time=start_time,
         dynamic_success_callbacks=dynamic_success_callbacks,
+        dynamic_failure_callbacks=dynamic_failure_callbacks,
         dynamic_async_success_callbacks=dynamic_async_success_callbacks,
         langfuse_public_key=kwargs.pop("langfuse_public_key", None),
-        langfuse_secret=kwargs.pop("langfuse_secret", None),
+        langfuse_secret=kwargs.pop("langfuse_secret", None)
+        or kwargs.pop("langfuse_secret_key", None),
     )
     ## check if metadata is passed in
     litellm_params = {"api_base": ""}
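With `function_setup` now popping `failure_callback` (and accepting `langfuse_secret_key` as an alias for `langfuse_secret`), callbacks and Langfuse credentials can be scoped to a single call. A hedged usage sketch; the credentials are placeholders:

```python
import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
    success_callback=["langfuse"],    # logged only for this call
    failure_callback=["langfuse"],    # new in this diff: per-call failure logging
    langfuse_public_key="pk-lf-...",  # placeholder credentials
    langfuse_secret_key="sk-lf-...",  # alias for langfuse_secret, per the change above
)
```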
@@ -4783,6 +4817,12 @@ def get_optional_params_embeddings(
                 status_code=500,
                 message=f"Setting dimensions is not supported for OpenAI `text-embedding-3` and later models. To drop it from the call, set `litellm.drop_params = True`.",
             )
+    if custom_llm_provider == "triton":
+        keys = list(non_default_params.keys())
+        for k in keys:
+            non_default_params.pop(k, None)
+        final_params = {**non_default_params, **kwargs}
+        return final_params
     if custom_llm_provider == "vertex_ai":
         if len(non_default_params.keys()) > 0:
             if litellm.drop_params is True:  # drop the unsupported non-default values
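For Triton embeddings the branch above drops every OpenAI-style non-default parameter and forwards only the remaining kwargs. A small self-contained illustration of that filtering behavior:

```python
def triton_optional_params(non_default_params: dict, **kwargs) -> dict:
    # mirror of the branch above: discard OpenAI-style params, keep provider kwargs
    for k in list(non_default_params.keys()):
        non_default_params.pop(k, None)
    return {**non_default_params, **kwargs}

print(triton_optional_params({"dimensions": 512}, api_base="http://localhost:8000"))
# -> {'api_base': 'http://localhost:8000'}
```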
@@ -4840,6 +4880,7 @@ def get_optional_params_embeddings(
 def get_optional_params(
     # use the openai defaults
     # https://platform.openai.com/docs/api-reference/chat/create
+    model: str,
     functions=None,
     function_call=None,
     temperature=None,
@@ -4853,7 +4894,6 @@ def get_optional_params(
     frequency_penalty=None,
     logit_bias=None,
     user=None,
-    model=None,
     custom_llm_provider="",
     response_format=None,
     seed=None,
@@ -4882,7 +4922,7 @@ def get_optional_params(

             passed_params[k] = v

-    optional_params = {}
+    optional_params: Dict = {}

     common_auth_dict = litellm.common_cloud_provider_auth_params
     if custom_llm_provider in common_auth_dict["providers"]:
@@ -5156,41 +5196,9 @@ def get_optional_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
         _check_valid_arg(supported_params=supported_params)
-        # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
-        if temperature is not None:
-            if temperature == 0.0 or temperature == 0:
-                # hugging face exception raised when temp==0
-                # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
-                temperature = 0.01
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if n is not None:
-            optional_params["best_of"] = n
-            optional_params["do_sample"] = (
-                True  # Need to sample if you want best of for hf inference endpoints
-            )
-        if stream is not None:
-            optional_params["stream"] = stream
-        if stop is not None:
-            optional_params["stop"] = stop
-        if max_tokens is not None:
-            # HF TGI raises the following exception when max_new_tokens==0
-            # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
-            if max_tokens == 0:
-                max_tokens = 1
-            optional_params["max_new_tokens"] = max_tokens
-        if n is not None:
-            optional_params["best_of"] = n
-        if presence_penalty is not None:
-            optional_params["repetition_penalty"] = presence_penalty
-        if "echo" in passed_params:
-            # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
-            # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
-            optional_params["decoder_input_details"] = special_params["echo"]
-            passed_params.pop(
-                "echo", None
-            )  # since we handle translating echo, we should not send it to TGI request
+        optional_params = litellm.HuggingfaceConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params=optional_params
+        )
     elif custom_llm_provider == "together_ai":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
@@ -5769,9 +5777,7 @@ def get_optional_params(
             extra_body  # openai client supports `extra_body` param
         )
     else:  # assume passing in params for openai/azure openai
-        print_verbose(
-            f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE - model={model}, custom_llm_provider={custom_llm_provider}"
-        )
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider="openai"
         )
@@ -6152,7 +6158,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "seed",
         ]
     elif custom_llm_provider == "huggingface":
-        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+        return litellm.HuggingfaceConfig().get_supported_openai_params()
     elif custom_llm_provider == "together_ai":
         return [
             "stream",
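Hugging Face parameter translation now lives in `litellm.HuggingfaceConfig` instead of the inline branch that was removed above. A hedged sketch of how the config object is used, mirroring the calls in this diff; the exact output keys (e.g. `max_new_tokens`, `best_of`) are what the old inline code produced and are assumed to be preserved by the new mapping:

```python
import litellm

config = litellm.HuggingfaceConfig()

# Which OpenAI params the provider accepts (replaces the hard-coded list).
print(config.get_supported_openai_params())

# Translate OpenAI-style params into TGI-style params.
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 256, "temperature": 0.0, "n": 2},
    optional_params={},
)
print(optional_params)  # expected to contain e.g. max_new_tokens / best_of / do_sample
```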
@@ -9408,6 +9414,72 @@ def get_secret(
     if secret_name.startswith("os.environ/"):
         secret_name = secret_name.replace("os.environ/", "")

+    # Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
+    if secret_name.startswith("oidc/"):
+        secret_name_split = secret_name.replace("oidc/", "")
+        oidc_provider, oidc_aud = secret_name_split.split("/", 1)
+        # TODO: Add caching for HTTP requests
+        match oidc_provider:
+            case "google":
+                oidc_token = oidc_cache.get_cache(key=secret_name)
+                if oidc_token is not None:
+                    return oidc_token
+
+                client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
+                # https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
+                response = client.get(
+                    "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
+                    params={"audience": oidc_aud},
+                    headers={"Metadata-Flavor": "Google"},
+                )
+                if response.status_code == 200:
+                    oidc_token = response.text
+                    oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
+                    return oidc_token
+                else:
+                    raise ValueError("Google OIDC provider failed")
+            case "circleci":
+                # https://circleci.com/docs/openid-connect-tokens/
+                env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
+                if env_secret is None:
+                    raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
+                return env_secret
+            case "circleci_v2":
+                # https://circleci.com/docs/openid-connect-tokens/
+                env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
+                if env_secret is None:
+                    raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
+                return env_secret
+            case "github":
+                # https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
+                actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
+                actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
+                if actions_id_token_request_url is None or actions_id_token_request_token is None:
+                    raise ValueError("ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment")
+
+                oidc_token = oidc_cache.get_cache(key=secret_name)
+                if oidc_token is not None:
+                    return oidc_token
+
+                client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
+                response = client.get(
+                    actions_id_token_request_url,
+                    params={"audience": oidc_aud},
+                    headers={
+                        "Authorization": f"Bearer {actions_id_token_request_token}",
+                        "Accept": "application/json; api-version=2.0",
+                    },
+                )
+                if response.status_code == 200:
+                    oidc_token = response.text['value']
+                    oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
+                    return oidc_token
+                else:
+                    raise ValueError("Github OIDC provider failed")
+            case _:
+                raise ValueError("Unsupported OIDC provider")
+
     try:
         if litellm.secret_manager_client is not None:
             try:
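The `oidc/<provider>/<audience>` convention lets any config field resolve to a freshly minted identity token. A hedged sketch of the parsing rule plus a direct call; the audience URL is the Bedrock example used in the tests above, and the call only succeeds inside an environment that exposes the matching OIDC credentials:

```python
from litellm.utils import get_secret

secret_name = (
    "oidc/github/https://bedrock-runtime.us-east-1.amazonaws.com"
    "/model/amazon.titan-text-express-v1/invoke"
)

# How the name is split inside get_secret():
provider, audience = secret_name.replace("oidc/", "").split("/", 1)
print(provider)  # "github"
print(audience)  # the full audience URL, with its slashes preserved

# Inside GitHub Actions (with id-token permissions) this returns a JWT;
# elsewhere it raises, since ACTIONS_ID_TOKEN_REQUEST_* env vars are missing.
token = get_secret(secret_name)
```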
@@ -1571,6 +1571,135 @@
         "litellm_provider": "replicate",
         "mode": "chat"
     },
+    "openrouter/microsoft/wizardlm-2-8x22b:nitro": {
+        "max_tokens": 65536,
+        "input_cost_per_token": 0.000001,
+        "output_cost_per_token": 0.000001,
+        "litellm_provider": "openrouter",
+        "mode": "chat"
+    },
+    "openrouter/google/gemini-pro-1.5": {
+        "max_tokens": 8192,
+        "max_input_tokens": 1000000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.0000025,
+        "output_cost_per_token": 0.0000075,
+        "input_cost_per_image": 0.00265,
+        "litellm_provider": "openrouter",
+        "mode": "chat",
+        "supports_function_calling": true,
+        "supports_vision": true
+    },
+    "openrouter/mistralai/mixtral-8x22b-instruct": {
+        "max_tokens": 65536,
+        "input_cost_per_token": 0.00000065,
+        "output_cost_per_token": 0.00000065,
+        "litellm_provider": "openrouter",
+        "mode": "chat"
+    },
+    "openrouter/cohere/command-r-plus": {
+        "max_tokens": 128000,
+        "input_cost_per_token": 0.000003,
+        "output_cost_per_token": 0.000015,
+        "litellm_provider": "openrouter",
+        "mode": "chat"
+    },
+    "openrouter/databricks/dbrx-instruct": {
+        "max_tokens": 32768,
+        "input_cost_per_token": 0.0000006,
+        "output_cost_per_token": 0.0000006,
+        "litellm_provider": "openrouter",
+        "mode": "chat"
+    },
+    "openrouter/anthropic/claude-3-haiku": {
+        "max_tokens": 200000,
+        "input_cost_per_token": 0.00000025,
+        "output_cost_per_token": 0.00000125,
+        "input_cost_per_image": 0.0004,
+        "litellm_provider": "openrouter",
+        "mode": "chat",
+        "supports_function_calling": true,
+        "supports_vision": true
+    },
+    "openrouter/anthropic/claude-3-sonnet": {
+        "max_tokens": 200000,
+        "input_cost_per_token": 0.000003,
+        "output_cost_per_token": 0.000015,
+        "input_cost_per_image": 0.0048,
+        "litellm_provider": "openrouter",
+        "mode": "chat",
+        "supports_function_calling": true,
+        "supports_vision": true
+    },
+    "openrouter/mistralai/mistral-large": {
+        "max_tokens": 32000,
+        "input_cost_per_token": 0.000008,
+        "output_cost_per_token": 0.000024,
+        "litellm_provider": "openrouter",
+        "mode": "chat"
+    },
+    "openrouter/cognitivecomputations/dolphin-mixtral-8x7b": {
+        "max_tokens": 32769,
+        "input_cost_per_token": 0.0000005,
+        "output_cost_per_token": 0.0000005,
+        "litellm_provider": "openrouter",
+        "mode": "chat"
+    },
+    "openrouter/google/gemini-pro-vision": {
+        "max_tokens": 45875,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000375,
+        "input_cost_per_image": 0.0025,
+        "litellm_provider": "openrouter",
+        "mode": "chat",
+        "supports_function_calling": true,
+        "supports_vision": true
+    },
+    "openrouter/fireworks/firellava-13b": {
+        "max_tokens": 4096,
+        "input_cost_per_token": 0.0000002,
+        "output_cost_per_token": 0.0000002,
+        "litellm_provider": "openrouter",
+        "mode": "chat"
+    },
+    "openrouter/meta-llama/llama-3-8b-instruct:free": {
+        "max_tokens": 8192,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "openrouter",
+        "mode": "chat"
+    },
+    "openrouter/meta-llama/llama-3-8b-instruct:extended": {
+        "max_tokens": 16384,
+        "input_cost_per_token": 0.000000225,
+        "output_cost_per_token": 0.00000225,
+        "litellm_provider": "openrouter",
+        "mode": "chat"
+    },
+    "openrouter/meta-llama/llama-3-70b-instruct:nitro": {
+        "max_tokens": 8192,
+        "input_cost_per_token": 0.0000009,
+        "output_cost_per_token": 0.0000009,
+        "litellm_provider": "openrouter",
+        "mode": "chat"
+    },
+    "openrouter/meta-llama/llama-3-70b-instruct": {
+        "max_tokens": 8192,
+        "input_cost_per_token": 0.00000059,
+        "output_cost_per_token": 0.00000079,
+        "litellm_provider": "openrouter",
+        "mode": "chat"
+    },
+    "openrouter/openai/gpt-4-vision-preview": {
+        "max_tokens": 130000,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003,
+        "input_cost_per_image": 0.01445,
+        "litellm_provider": "openrouter",
+        "mode": "chat",
+        "supports_function_calling": true,
+        "supports_vision": true
+    },
     "openrouter/openai/gpt-3.5-turbo": {
         "max_tokens": 4095,
         "input_cost_per_token": 0.0000015,
@@ -1621,14 +1750,14 @@
         "tool_use_system_prompt_tokens": 395
     },
     "openrouter/google/palm-2-chat-bison": {
-        "max_tokens": 8000,
+        "max_tokens": 25804,
         "input_cost_per_token": 0.0000005,
         "output_cost_per_token": 0.0000005,
         "litellm_provider": "openrouter",
         "mode": "chat"
     },
     "openrouter/google/palm-2-codechat-bison": {
-        "max_tokens": 8000,
+        "max_tokens": 20070,
         "input_cost_per_token": 0.0000005,
         "output_cost_per_token": 0.0000005,
         "litellm_provider": "openrouter",
@@ -1711,13 +1840,6 @@
         "litellm_provider": "openrouter",
         "mode": "chat"
     },
-    "openrouter/meta-llama/llama-3-70b-instruct": {
-        "max_tokens": 8192,
-        "input_cost_per_token": 0.0000008,
-        "output_cost_per_token": 0.0000008,
-        "litellm_provider": "openrouter",
-        "mode": "chat"
-    },
     "j2-ultra": {
         "max_tokens": 8192,
         "max_input_tokens": 8192,
@@ -3226,4 +3348,4 @@
         "mode": "embedding"
     }

 }
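Once a model has an entry in the pricing map, litellm can derive per-token cost for it. A hedged check against one of the new OpenRouter entries:

```python
import litellm

prompt_cost, completion_cost = litellm.cost_per_token(
    model="openrouter/anthropic/claude-3-haiku",
    prompt_tokens=1000,
    completion_tokens=1000,
)
print(prompt_cost, completion_cost)
# expected roughly 0.00025 and 0.00125, given the per-token prices added above
```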
@@ -92,10 +92,12 @@ litellm_settings:
   default_team_settings:
     - team_id: team-1
       success_callback: ["langfuse"]
+      failure_callback: ["langfuse"]
       langfuse_public_key: os.environ/LANGFUSE_PROJECT1_PUBLIC # Project 1
       langfuse_secret: os.environ/LANGFUSE_PROJECT1_SECRET # Project 1
     - team_id: team-2
       success_callback: ["langfuse"]
+      failure_callback: ["langfuse"]
       langfuse_public_key: os.environ/LANGFUSE_PROJECT2_PUBLIC # Project 2
       langfuse_secret: os.environ/LANGFUSE_PROJECT2_SECRET # Project 2
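With `failure_callback` added to `default_team_settings`, failed requests made with a team's virtual key should be logged to that team's Langfuse project as well, not only the successes. A hedged sketch of exercising this through the proxy with the OpenAI SDK; the key and base URL are placeholders:

```python
import openai

client = openai.OpenAI(
    api_key="sk-team1-key",          # a virtual key belonging to team-1 (placeholder)
    base_url="http://0.0.0.0:4000",  # LiteLLM proxy
)

try:
    client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
    )
except Exception:
    # the failure is expected to appear in team-1's Langfuse project
    pass
```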
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.37.0"
+version = "1.37.4"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -80,7 +80,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"

 [tool.commitizen]
-version = "1.37.0"
+version = "1.37.4"
 version_files = [
     "pyproject.toml:^version"
 ]
@@ -246,6 +246,33 @@ async def get_model_info_v2(session, key):
         raise Exception(f"Request did not return a 200 status code: {status}")


+async def get_specific_model_info_v2(session, key, model_name):
+    url = "http://0.0.0.0:4000/v2/model/info?debug=True&model=" + model_name
+    print("running /model/info check for model=", model_name)
+
+    headers = {
+        "Authorization": f"Bearer {key}",
+        "Content-Type": "application/json",
+    }
+
+    async with session.get(url, headers=headers) as response:
+        status = response.status
+        response_text = await response.text()
+        print("response from v2/model/info")
+        print(response_text)
+        print()
+
+        _json_response = await response.json()
+        print("JSON response from /v2/model/info?model=", model_name, _json_response)
+
+        _model_info = _json_response["data"]
+        assert len(_model_info) == 1, f"Expected 1 model, got {len(_model_info)}"
+
+        if status != 200:
+            raise Exception(f"Request did not return a 200 status code: {status}")
+        return _model_info[0]
+
+
 async def get_model_health(session, key, model_name):
     url = "http://0.0.0.0:4000/health?model=" + model_name
     headers = {
@@ -285,6 +312,11 @@ async def test_add_model_run_health():
         model_name = f"azure-model-health-check-{model_id}"
         print("adding model", model_name)
         await add_model_for_health_checking(session=session, model_id=model_id)
+        _old_model_info = await get_specific_model_info_v2(
+            session=session, key=key, model_name=model_name
+        )
+        print("model info before test", _old_model_info)
+
         await asyncio.sleep(30)
         print("calling /model/info")
         await get_model_info(session=session, key=key)
@@ -305,5 +337,28 @@ async def test_add_model_run_health():
             _healthy_endpooint["model"] == "azure/chatgpt-v-2"
         )  # this is the model that got added

+        # assert httpx client is is unchanges
+        await asyncio.sleep(10)
+
+        _model_info_after_test = await get_specific_model_info_v2(
+            session=session, key=key, model_name=model_name
+        )
+
+        print("model info after test", _model_info_after_test)
+        old_openai_client = _old_model_info["openai_client"]
+        new_openai_client = _model_info_after_test["openai_client"]
+        print("old openai client", old_openai_client)
+        print("new openai client", new_openai_client)
+
+        """
+        PROD TEST - This is extremly important
+        The OpenAI client used should be the same after 30 seconds
+        It is a serious bug if the openai client does not match here
+        """
+        assert (
+            old_openai_client == new_openai_client
+        ), "OpenAI client does not match for the same model after 30 seconds"
+
         # cleanup
         await delete_model(session=session, model_id=model_id)
@@ -2,8 +2,17 @@

 import React, { useState, useEffect, useRef } from "react";
 import { Button, TextInput, Grid, Col } from "@tremor/react";
-import { Card, Metric, Text, Title, Subtitle, Accordion, AccordionHeader, AccordionBody, } from "@tremor/react";
-import { CopyToClipboard } from 'react-copy-to-clipboard';
+import {
+  Card,
+  Metric,
+  Text,
+  Title,
+  Subtitle,
+  Accordion,
+  AccordionHeader,
+  AccordionBody,
+} from "@tremor/react";
+import { CopyToClipboard } from "react-copy-to-clipboard";
 import {
   Button as Button2,
   Modal,
@@ -13,7 +22,11 @@ import {
   Select,
   message,
 } from "antd";
-import { keyCreateCall, slackBudgetAlertsHealthCheck, modelAvailableCall } from "./networking";
+import {
+  keyCreateCall,
+  slackBudgetAlertsHealthCheck,
+  modelAvailableCall,
+} from "./networking";

 const { Option } = Select;

@@ -59,7 +72,11 @@ const CreateKey: React.FC<CreateKeyProps> = ({
       }

       if (accessToken !== null) {
-        const model_available = await modelAvailableCall(accessToken, userID, userRole);
+        const model_available = await modelAvailableCall(
+          accessToken,
+          userID,
+          userRole
+        );
         let available_model_names = model_available["data"].map(
           (element: { id: string }) => element.id
         );
@@ -70,12 +87,25 @@ const CreateKey: React.FC<CreateKeyProps> = ({
         console.error("Error fetching user models:", error);
       }
     };

     fetchUserModels();
   }, [accessToken, userID, userRole]);

   const handleCreate = async (formValues: Record<string, any>) => {
     try {
+      const newKeyAlias = formValues?.key_alias ?? "";
+      const newKeyTeamId = formValues?.team_id ?? null;
+      const existingKeyAliases =
+        data
+          ?.filter((k) => k.team_id === newKeyTeamId)
+          .map((k) => k.key_alias) ?? [];
+
+      if (existingKeyAliases.includes(newKeyAlias)) {
+        throw new Error(
+          `Key alias ${newKeyAlias} already exists for team with ID ${newKeyTeamId}, please provide another key alias`
+        );
+      }
+
       message.info("Making API Call");
       setIsModalVisible(true);
       const response = await keyCreateCall(accessToken, userID, formValues);
@@ -89,12 +119,13 @@ const CreateKey: React.FC<CreateKeyProps> = ({
       localStorage.removeItem("userData" + userID);
     } catch (error) {
       console.error("Error creating the key:", error);
+      message.error(`Error creating the key: ${error}`, 20);
     }
   };

   const handleCopy = () => {
-    message.success('API Key copied to clipboard');
+    message.success("API Key copied to clipboard");
   };

   useEffect(() => {
     let tempModelsToPick = [];
@@ -119,7 +150,6 @@ const CreateKey: React.FC<CreateKeyProps> = ({

     setModelsToPick(tempModelsToPick);
   }, [team, userModels]);
-

   return (
     <div>
@@ -141,140 +171,164 @@ const CreateKey: React.FC<CreateKeyProps> = ({
[formatting-only reflow of the key-creation form JSX (Key Name, Team ID, Models select with the "All Team Models" option, and the Optional Settings accordion with Max Budget, Reset Budget, TPM limit, RPM limit, Expire Key, and Metadata fields, including their team-limit validators): single quotes normalized to double quotes and the JSX re-indented; no functional change]
         <div style={{ textAlign: "right", marginTop: "10px" }}>
           <Button2 htmlType="submit">Create Key</Button2>
         </div>
@@ -288,36 +342,45 @@ const CreateKey: React.FC<CreateKeyProps> = ({
           footer={null}
         >
           <Grid numItems={1} className="gap-2 w-full">
[formatting-only reflow of the "Save your Key" modal body (Title, warning text, API key <pre> block, and CopyToClipboard button, with the commented-out "Test Key" button): inline styles split across lines and quotes normalized; no functional change]
               </div>
             ) : (
               <Text>Key being created, this might take 30s</Text>
             )}
           </Col>
         </Grid>
       </Modal>
     )}
@ -2,7 +2,13 @@ import React, { useState, useEffect } from "react";
|
||||||
import Link from "next/link";
|
import Link from "next/link";
|
||||||
import { Typography } from "antd";
|
import { Typography } from "antd";
|
||||||
import { teamDeleteCall, teamUpdateCall, teamInfoCall } from "./networking";
|
import { teamDeleteCall, teamUpdateCall, teamInfoCall } from "./networking";
|
||||||
import { InformationCircleIcon, PencilAltIcon, PencilIcon, StatusOnlineIcon, TrashIcon } from "@heroicons/react/outline";
|
import {
|
||||||
|
InformationCircleIcon,
|
||||||
|
PencilAltIcon,
|
||||||
|
PencilIcon,
|
||||||
|
StatusOnlineIcon,
|
||||||
|
TrashIcon,
|
||||||
|
} from "@heroicons/react/outline";
|
||||||
import {
|
import {
|
||||||
Button as Button2,
|
Button as Button2,
|
||||||
Modal,
|
Modal,
|
||||||
|
@ -46,8 +52,12 @@ interface EditTeamModalProps {
|
||||||
onSubmit: (data: FormData) => void; // Assuming FormData is the type of data to be submitted
|
onSubmit: (data: FormData) => void; // Assuming FormData is the type of data to be submitted
|
||||||
}
|
}
|
||||||
|
|
||||||
|
import {
|
||||||
import { teamCreateCall, teamMemberAddCall, Member, modelAvailableCall } from "./networking";
|
teamCreateCall,
|
||||||
|
teamMemberAddCall,
|
||||||
|
Member,
|
||||||
|
modelAvailableCall,
|
||||||
|
} from "./networking";
|
||||||
|
|
||||||
const Team: React.FC<TeamProps> = ({
|
const Team: React.FC<TeamProps> = ({
|
||||||
teams,
|
teams,
|
||||||
|
@ -63,7 +73,6 @@ const Team: React.FC<TeamProps> = ({
|
||||||
const [value, setValue] = useState("");
|
const [value, setValue] = useState("");
|
||||||
const [editModalVisible, setEditModalVisible] = useState(false);
|
const [editModalVisible, setEditModalVisible] = useState(false);
|
||||||
|
|
||||||
|
|
||||||
const [selectedTeam, setSelectedTeam] = useState<null | any>(
|
const [selectedTeam, setSelectedTeam] = useState<null | any>(
|
||||||
teams ? teams[0] : null
|
teams ? teams[0] : null
|
||||||
);
|
);
|
||||||
|
@ -76,127 +85,125 @@ const Team: React.FC<TeamProps> = ({
|
||||||
// store team info as {"team_id": team_info_object}
|
// store team info as {"team_id": team_info_object}
|
||||||
const [perTeamInfo, setPerTeamInfo] = useState<Record<string, any>>({});
|
const [perTeamInfo, setPerTeamInfo] = useState<Record<string, any>>({});
|
||||||
|
|
||||||
|
const EditTeamModal: React.FC<EditTeamModalProps> = ({
|
||||||
|
visible,
|
||||||
|
onCancel,
|
||||||
|
team,
|
||||||
|
onSubmit,
|
||||||
|
}) => {
|
||||||
|
const [form] = Form.useForm();
|
||||||
|
|
||||||
const EditTeamModal: React.FC<EditTeamModalProps> = ({ visible, onCancel, team, onSubmit }) => {
|
const handleOk = () => {
|
||||||
const [form] = Form.useForm();
|
form
|
||||||
|
.validateFields()
|
||||||
|
.then((values) => {
|
||||||
|
const updatedValues = { ...values, team_id: team.team_id };
|
||||||
|
onSubmit(updatedValues);
|
||||||
|
form.resetFields();
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error("Validation failed:", error);
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
const handleOk = () => {
|
return (
|
||||||
form
|
|
||||||
.validateFields()
|
|
||||||
.then((values) => {
|
|
||||||
const updatedValues = {...values, team_id: team.team_id};
|
|
||||||
onSubmit(updatedValues);
|
|
||||||
form.resetFields();
|
|
||||||
})
|
|
||||||
.catch((error) => {
|
|
||||||
console.error("Validation failed:", error);
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Modal
|
<Modal
|
||||||
title="Edit Team"
|
title="Edit Team"
|
||||||
visible={visible}
|
visible={visible}
|
||||||
width={800}
|
width={800}
|
||||||
footer={null}
|
footer={null}
|
||||||
onOk={handleOk}
|
onOk={handleOk}
|
||||||
onCancel={onCancel}
|
onCancel={onCancel}
|
||||||
>
|
|
||||||
<Form
|
|
||||||
form={form}
|
|
||||||
onFinish={handleEditSubmit}
|
|
||||||
initialValues={team} // Pass initial values here
|
|
||||||
labelCol={{ span: 8 }}
|
|
||||||
wrapperCol={{ span: 16 }}
|
|
||||||
labelAlign="left"
|
|
||||||
>
|
>
|
||||||
<>
|
<Form
|
||||||
<Form.Item
|
form={form}
|
||||||
label="Team Name"
|
onFinish={handleEditSubmit}
|
||||||
name="team_alias"
|
initialValues={team} // Pass initial values here
|
||||||
rules={[{ required: true, message: 'Please input a team name' }]}
|
labelCol={{ span: 8 }}
|
||||||
>
|
wrapperCol={{ span: 16 }}
|
||||||
<Input />
|
labelAlign="left"
|
||||||
</Form.Item>
|
>
|
||||||
<Form.Item label="Models" name="models">
|
<>
|
||||||
<Select2
|
<Form.Item
|
||||||
mode="multiple"
|
label="Team Name"
|
||||||
placeholder="Select models"
|
name="team_alias"
|
||||||
style={{ width: "100%" }}
|
rules={[{ required: true, message: "Please input a team name" }]}
|
||||||
>
|
>
|
||||||
<Select2.Option key="all-proxy-models" value="all-proxy-models">
|
<Input />
|
||||||
{"All Proxy Models"}
|
</Form.Item>
|
||||||
</Select2.Option>
|
<Form.Item label="Models" name="models">
|
||||||
{userModels && userModels.map((model) => (
|
<Select2
|
||||||
<Select2.Option key={model} value={model}>
|
mode="multiple"
|
||||||
{model}
|
placeholder="Select models"
|
||||||
</Select2.Option>
|
style={{ width: "100%" }}
|
||||||
))}
|
>
|
||||||
|
<Select2.Option key="all-proxy-models" value="all-proxy-models">
|
||||||
</Select2>
|
{"All Proxy Models"}
|
||||||
</Form.Item>
|
</Select2.Option>
|
||||||
<Form.Item label="Max Budget (USD)" name="max_budget">
|
{userModels &&
|
||||||
<InputNumber step={0.01} precision={2} width={200} />
|
userModels.map((model) => (
|
||||||
</Form.Item>
|
<Select2.Option key={model} value={model}>
|
||||||
<Form.Item
|
{model}
|
||||||
label="Tokens per minute Limit (TPM)"
|
</Select2.Option>
|
||||||
name="tpm_limit"
|
))}
|
||||||
>
|
</Select2>
|
||||||
<InputNumber step={1} width={400} />
|
</Form.Item>
|
||||||
</Form.Item>
|
<Form.Item label="Max Budget (USD)" name="max_budget">
|
||||||
<Form.Item
|
<InputNumber step={0.01} precision={2} width={200} />
|
||||||
label="Requests per minute Limit (RPM)"
|
</Form.Item>
|
||||||
name="rpm_limit"
|
<Form.Item label="Tokens per minute Limit (TPM)" name="tpm_limit">
|
||||||
>
|
<InputNumber step={1} width={400} />
|
||||||
<InputNumber step={1} width={400} />
|
</Form.Item>
|
||||||
</Form.Item>
|
<Form.Item label="Requests per minute Limit (RPM)" name="rpm_limit">
|
||||||
<Form.Item
|
<InputNumber step={1} width={400} />
|
||||||
label="Requests per minute Limit (RPM)"
|
</Form.Item>
|
||||||
name="team_id"
|
<Form.Item
|
||||||
hidden={true}
|
label="Requests per minute Limit (RPM)"
|
||||||
></Form.Item>
|
name="team_id"
|
||||||
</>
|
hidden={true}
|
||||||
<div style={{ textAlign: "right", marginTop: "10px" }}>
|
></Form.Item>
|
||||||
<Button2 htmlType="submit">Edit Team</Button2>
|
</>
|
||||||
</div>
|
<div style={{ textAlign: "right", marginTop: "10px" }}>
|
||||||
</Form>
|
<Button2 htmlType="submit">Edit Team</Button2>
|
||||||
</Modal>
|
</div>
|
||||||
);
|
</Form>
|
||||||
};
|
</Modal>
|
||||||
|
|
||||||
const handleEditClick = (team: any) => {
|
|
||||||
setSelectedTeam(team);
|
|
||||||
setEditModalVisible(true);
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleEditCancel = () => {
|
|
||||||
setEditModalVisible(false);
|
|
||||||
setSelectedTeam(null);
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleEditSubmit = async (formValues: Record<string, any>) => {
|
|
||||||
// Call API to update team with teamId and values
|
|
||||||
const teamId = formValues.team_id; // get team_id
|
|
||||||
|
|
||||||
console.log("handleEditSubmit:", formValues);
|
|
||||||
if (accessToken == null) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let newTeamValues = await teamUpdateCall(accessToken, formValues);
|
|
||||||
|
|
||||||
// Update the teams state with the updated team data
|
|
||||||
if (teams) {
|
|
||||||
const updatedTeams = teams.map((team) =>
|
|
||||||
team.team_id === teamId ? newTeamValues.data : team
|
|
||||||
);
|
);
|
||||||
setTeams(updatedTeams);
|
};
|
||||||
}
|
|
||||||
message.success("Team updated successfully");
|
|
||||||
|
|
||||||
setEditModalVisible(false);
|
const handleEditClick = (team: any) => {
|
||||||
setSelectedTeam(null);
|
setSelectedTeam(team);
|
||||||
};
|
setEditModalVisible(true);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleEditCancel = () => {
|
||||||
|
setEditModalVisible(false);
|
||||||
|
setSelectedTeam(null);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleEditSubmit = async (formValues: Record<string, any>) => {
|
||||||
|
// Call API to update team with teamId and values
|
||||||
|
const teamId = formValues.team_id; // get team_id
|
||||||
|
|
||||||
|
console.log("handleEditSubmit:", formValues);
|
||||||
|
if (accessToken == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let newTeamValues = await teamUpdateCall(accessToken, formValues);
|
||||||
|
|
||||||
|
// Update the teams state with the updated team data
|
||||||
|
if (teams) {
|
||||||
|
const updatedTeams = teams.map((team) =>
|
||||||
|
team.team_id === teamId ? newTeamValues.data : team
|
||||||
|
);
|
||||||
|
setTeams(updatedTeams);
|
||||||
|
}
|
||||||
|
message.success("Team updated successfully");
|
||||||
|
|
||||||
|
setEditModalVisible(false);
|
||||||
|
setSelectedTeam(null);
|
||||||
|
};
|
||||||
|
|
||||||
const handleOk = () => {
|
const handleOk = () => {
|
||||||
setIsTeamModalVisible(false);
|
setIsTeamModalVisible(false);
|
||||||
|
@@ -224,9 +231,6 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
     setIsDeleteModalOpen(true);
   };
 
-
-
-
   const confirmDelete = async () => {
     if (teamToDelete == null || teams == null || accessToken == null) {
       return;
@@ -235,7 +239,9 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
     try {
       await teamDeleteCall(accessToken, teamToDelete);
       // Successfully completed the deletion. Update the state to trigger a rerender.
-      const filteredData = teams.filter((item) => item.team_id !== teamToDelete);
+      const filteredData = teams.filter(
+        (item) => item.team_id !== teamToDelete
+      );
       setTeams(filteredData);
     } catch (error) {
       console.error("Error deleting the team:", error);
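
Both handlers shown in the hunks above keep the client-side `teams` list in sync without a refetch: `handleEditSubmit` swaps the updated record in by `team_id`, and `confirmDelete` filters the deleted one out. A minimal standalone sketch of that pattern, assuming a `Team` shape with a `team_id` field; the helper names are illustrative and not part of this file:

```typescript
// Illustrative only: the map-replace / filter-remove updates used by
// handleEditSubmit and confirmDelete, extracted as pure helpers.
interface Team {
  team_id: string;
  team_alias: string; // other fields omitted
}

// Replace the edited team in place, matching on team_id.
const replaceTeam = (teams: Team[], updated: Team): Team[] =>
  teams.map((team) => (team.team_id === updated.team_id ? updated : team));

// Drop the deleted team, matching on team_id.
const removeTeam = (teams: Team[], teamId: string): Team[] =>
  teams.filter((team) => team.team_id !== teamId);
```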
@@ -253,8 +259,6 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
     setTeamToDelete(null);
   };
 
-
-
   useEffect(() => {
     const fetchUserModels = async () => {
       try {
@@ -263,7 +267,11 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
         }
 
         if (accessToken !== null) {
-          const model_available = await modelAvailableCall(accessToken, userID, userRole);
+          const model_available = await modelAvailableCall(
+            accessToken,
+            userID,
+            userRole
+          );
           let available_model_names = model_available["data"].map(
             (element: { id: string }) => element.id
           );
@@ -275,7 +283,6 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
       }
     };
-
 
     const fetchTeamInfo = async () => {
       try {
         if (userID === null || userRole === null || accessToken === null) {
@@ -288,22 +295,21 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
 
        console.log("fetching team info:");
-
 
        let _team_id_to_info: Record<string, any> = {};
        for (let i = 0; i < teams?.length; i++) {
          let _team_id = teams[i].team_id;
          const teamInfo = await teamInfoCall(accessToken, _team_id);
          console.log("teamInfo response:", teamInfo);
          if (teamInfo !== null) {
-           _team_id_to_info = {..._team_id_to_info, [_team_id]: teamInfo};
+           _team_id_to_info = { ..._team_id_to_info, [_team_id]: teamInfo };
          }
        }
        setPerTeamInfo(_team_id_to_info);
      } catch (error) {
        console.error("Error fetching team info:", error);
      }
    };
 
    fetchUserModels();
    fetchTeamInfo();
  }, [accessToken, userID, userRole, teams]);
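
The loop above builds `perTeamInfo` as a plain object keyed by `team_id`; the table cells further down read `keys.length` and `team_info.members_with_roles.length` from it. A rough sketch of that lookup, assuming the `teamInfoCall` response carries those two fields; any field shape beyond that is an assumption, not shown in this diff:

```typescript
// Assumed response shape: only `keys` and `team_info.members_with_roles`
// are actually read by the table below; everything else is illustrative.
interface TeamInfoResponse {
  keys: unknown[];
  team_info: { members_with_roles: { user_id: string; role: string }[] };
}

type PerTeamInfo = Record<string, TeamInfoResponse>; // keyed by team_id

const keyCount = (info: PerTeamInfo, teamId: string): number =>
  info[teamId]?.keys?.length ?? 0;

const memberCount = (info: PerTeamInfo, teamId: string): number =>
  info[teamId]?.team_info?.members_with_roles?.length ?? 0;
```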
@@ -311,6 +317,15 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
   const handleCreate = async (formValues: Record<string, any>) => {
     try {
       if (accessToken != null) {
+        const newTeamAlias = formValues?.team_alias;
+        const existingTeamAliases = teams?.map((t) => t.team_alias) ?? [];
+
+        if (existingTeamAliases.includes(newTeamAlias)) {
+          throw new Error(
+            `Team alias ${newTeamAlias} already exists, please pick another alias`
+          );
+        }
+
         message.info("Creating Team");
         const response: any = await teamCreateCall(accessToken, formValues);
         if (teams !== null) {
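
The added lines in this hunk put a client-side uniqueness check on the team alias before `teamCreateCall` runs; the thrown error falls through to the `catch` logged in the next hunk. The same guard as a standalone sketch, assuming `teams` is the component's array of objects with a `team_alias` field (`assertUniqueAlias` is a hypothetical helper name):

```typescript
// Hypothetical helper mirroring the guard added in handleCreate.
const assertUniqueAlias = (
  teams: { team_alias: string }[] | null,
  newTeamAlias: string
): void => {
  const existingTeamAliases = teams?.map((t) => t.team_alias) ?? [];
  if (existingTeamAliases.includes(newTeamAlias)) {
    throw new Error(
      `Team alias ${newTeamAlias} already exists, please pick another alias`
    );
  }
};

// e.g. assertUniqueAlias(teams, formValues?.team_alias) before teamCreateCall(...)
```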
@@ -364,7 +379,7 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
       console.error("Error creating the team:", error);
     }
   };
-  console.log(`received teams ${teams}`);
+  console.log(`received teams ${JSON.stringify(teams)}`);
   return (
     <div className="w-full mx-4">
       <Grid numItems={1} className="gap-2 p-8 h-[75vh] w-full mt-2">
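
The only non-formatting change in this hunk is the debug log: interpolating an array of objects directly into a template literal prints `[object Object]`, while `JSON.stringify` prints the payload. A quick illustration with made-up values:

```typescript
const teams = [{ team_id: "t1", team_alias: "alpha" }];
console.log(`received teams ${teams}`);
// -> received teams [object Object]
console.log(`received teams ${JSON.stringify(teams)}`);
// -> received teams [{"team_id":"t1","team_alias":"alpha"}]
```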
@@ -387,55 +402,124 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
             {teams && teams.length > 0
               ? teams.map((team: any) => (
                   <TableRow key={team.team_id}>
-                    <TableCell style={{ maxWidth: "4px", whiteSpace: "pre-wrap", overflow: "hidden" }}>{team["team_alias"]}</TableCell>
-                    <TableCell style={{ maxWidth: "4px", whiteSpace: "pre-wrap", overflow: "hidden" }}>{team["spend"]}</TableCell>
-                    <TableCell style={{ maxWidth: "4px", whiteSpace: "pre-wrap", overflow: "hidden" }}>
+                    <TableCell
+                      style={{
+                        maxWidth: "4px",
+                        whiteSpace: "pre-wrap",
+                        overflow: "hidden",
+                      }}
+                    >
+                      {team["team_alias"]}
+                    </TableCell>
+                    <TableCell
+                      style={{
+                        maxWidth: "4px",
+                        whiteSpace: "pre-wrap",
+                        overflow: "hidden",
+                      }}
+                    >
+                      {team["spend"]}
+                    </TableCell>
+                    <TableCell
+                      style={{
+                        maxWidth: "4px",
+                        whiteSpace: "pre-wrap",
+                        overflow: "hidden",
+                      }}
+                    >
                       {team["max_budget"] ? team["max_budget"] : "No limit"}
                     </TableCell>
-                    <TableCell style={{ maxWidth: "8-x", whiteSpace: "pre-wrap", overflow: "hidden" }}>
+                    <TableCell
+                      style={{
+                        maxWidth: "8-x",
+                        whiteSpace: "pre-wrap",
+                        overflow: "hidden",
+                      }}
+                    >
                       {Array.isArray(team.models) ? (
-                        <div style={{ display: "flex", flexDirection: "column" }}>
+                        <div
+                          style={{
+                            display: "flex",
+                            flexDirection: "column",
+                          }}
+                        >
                           {team.models.length === 0 ? (
                             <Badge size={"xs"} className="mb-1" color="red">
                               <Text>All Proxy Models</Text>
                             </Badge>
                           ) : (
-                            team.models.map((model: string, index: number) => (
-                              model === "all-proxy-models" ? (
-                                <Badge key={index} size={"xs"} className="mb-1" color="red">
-                                  <Text>All Proxy Models</Text>
-                                </Badge>
-                              ) : (
-                                <Badge key={index} size={"xs"} className="mb-1" color="blue">
-                                  <Text>{model.length > 30 ? `${model.slice(0, 30)}...` : model}</Text>
-                                </Badge>
-                              )
-                            ))
+                            team.models.map(
+                              (model: string, index: number) =>
+                                model === "all-proxy-models" ? (
+                                  <Badge
+                                    key={index}
+                                    size={"xs"}
+                                    className="mb-1"
+                                    color="red"
+                                  >
+                                    <Text>All Proxy Models</Text>
+                                  </Badge>
+                                ) : (
+                                  <Badge
+                                    key={index}
+                                    size={"xs"}
+                                    className="mb-1"
+                                    color="blue"
+                                  >
+                                    <Text>
+                                      {model.length > 30
+                                        ? `${model.slice(0, 30)}...`
+                                        : model}
+                                    </Text>
+                                  </Badge>
+                                )
+                            )
                           )}
                         </div>
                       ) : null}
                     </TableCell>
 
 
-                    <TableCell style={{ maxWidth: "4px", whiteSpace: "pre-wrap", overflow: "hidden" }}>
+                    <TableCell
+                      style={{
+                        maxWidth: "4px",
+                        whiteSpace: "pre-wrap",
+                        overflow: "hidden",
+                      }}
+                    >
                       <Text>
-                        TPM:{" "}
-                        {team.tpm_limit ? team.tpm_limit : "Unlimited"}{" "}
+                        TPM: {team.tpm_limit ? team.tpm_limit : "Unlimited"}{" "}
                         <br></br>RPM:{" "}
                         {team.rpm_limit ? team.rpm_limit : "Unlimited"}
                       </Text>
                     </TableCell>
                     <TableCell>
-                      <Text>{perTeamInfo && team.team_id && perTeamInfo[team.team_id] && perTeamInfo[team.team_id].keys && perTeamInfo[team.team_id].keys.length} Keys</Text>
-                      <Text>{perTeamInfo && team.team_id && perTeamInfo[team.team_id] && perTeamInfo[team.team_id].team_info && perTeamInfo[team.team_id].team_info.members_with_roles && perTeamInfo[team.team_id].team_info.members_with_roles.length} Members</Text>
+                      <Text>
+                        {perTeamInfo &&
+                          team.team_id &&
+                          perTeamInfo[team.team_id] &&
+                          perTeamInfo[team.team_id].keys &&
+                          perTeamInfo[team.team_id].keys.length}{" "}
+                        Keys
+                      </Text>
+                      <Text>
+                        {perTeamInfo &&
+                          team.team_id &&
+                          perTeamInfo[team.team_id] &&
+                          perTeamInfo[team.team_id].team_info &&
+                          perTeamInfo[team.team_id].team_info
+                            .members_with_roles &&
+                          perTeamInfo[team.team_id].team_info
+                            .members_with_roles.length}{" "}
+                        Members
+                      </Text>
                     </TableCell>
                     <TableCell>
                       <Icon
                         icon={PencilAltIcon}
                         size="sm"
                         onClick={() => handleEditClick(team)}
                       />
                       <Icon
                         onClick={() => handleDelete(team.team_id)}
                         icon={TrashIcon}
                         size="sm"
@@ -481,7 +565,11 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
             </div>
           </div>
           <div className="bg-gray-50 px-4 py-3 sm:px-6 sm:flex sm:flex-row-reverse">
-            <Button onClick={confirmDelete} color="red" className="ml-2">
+            <Button
+              onClick={confirmDelete}
+              color="red"
+              className="ml-2"
+            >
               Delete
             </Button>
             <Button onClick={cancelDelete}>Cancel</Button>
@@ -515,10 +603,12 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
           labelAlign="left"
         >
           <>
             <Form.Item
               label="Team Name"
               name="team_alias"
-              rules={[{ required: true, message: 'Please input a team name' }]}
+              rules={[
+                { required: true, message: "Please input a team name" },
+              ]}
             >
               <TextInput placeholder="" />
             </Form.Item>
@@ -528,7 +618,10 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
               placeholder="Select models"
               style={{ width: "100%" }}
             >
-              <Select2.Option key="all-proxy-models" value="all-proxy-models">
+              <Select2.Option
+                key="all-proxy-models"
+                value="all-proxy-models"
+              >
                 All Proxy Models
               </Select2.Option>
               {userModels.map((model) => (
@@ -606,8 +699,8 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
                       {member["user_email"]
                         ? member["user_email"]
                         : member["user_id"]
                         ? member["user_id"]
                         : null}
                     </TableCell>
                     <TableCell>{member["role"]}</TableCell>
                   </TableRow>
@@ -618,13 +711,13 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
             </Table>
           </Card>
           {selectedTeam && (
             <EditTeamModal
               visible={editModalVisible}
               onCancel={handleEditCancel}
               team={selectedTeam}
               onSubmit={handleEditSubmit}
             />
           )}
         </Col>
         <Col numColSpan={1}>
           <Button