forked from phoenix/litellm-mirror
Merge branch 'BerriAI:main' into add_helm_chart
Commit 8876b47e94
38 changed files with 860 additions and 237 deletions
@@ -156,7 +156,8 @@ jobs:
     --config /app/config.yaml \
     --port 4000 \
     --num_workers 8 \
-    --debug
+    --debug \
+    --run_gunicorn \
   - run:
       name: Install curl and dockerize
       command: |
@@ -52,4 +52,4 @@ RUN chmod +x entrypoint.sh
 EXPOSE 4000/tcp

 ENTRYPOINT ["litellm"]
-CMD ["--port", "4000", "--config", "./proxy_server_config.yaml", "--detailed_debug"]
+CMD ["--port", "4000", "--config", "./proxy_server_config.yaml", "--detailed_debug", "--run_gunicorn"]
@@ -56,4 +56,4 @@ EXPOSE 4000/tcp
 # # Set your entrypoint and command

 ENTRYPOINT ["litellm"]
-CMD ["--port", "4000"]
+CMD ["--port", "4000", "--run_gunicorn"]
@@ -188,7 +188,7 @@ print(response)
 </Tabs>


-## Save Model-specific params (API Base, API Keys, Temperature, Max Tokens, Seed, Headers etc.)
+## Save Model-specific params (API Base, API Keys, Temperature, Max Tokens, Seed, Organization, Headers etc.)
 You can use the config to save model-specific information like api_base, api_key, temperature, max_tokens, etc.

 [**All input params**](https://docs.litellm.ai/docs/completion/input#input-params-1)
@@ -210,6 +210,12 @@ model_list:
       api_key: sk-123
       api_base: https://openai-gpt-4-test-v-2.openai.azure.com/
       temperature: 0.2
+  - model_name: openai-gpt-3.5
+    litellm_params:
+      model: openai/gpt-3.5-turbo
+      api_key: sk-123
+      organization: org-ikDc4ex8NB
+      temperature: 0.2
   - model_name: mistral-7b
     litellm_params:
       model: ollama/mistral
@@ -483,3 +489,55 @@ general_settings:
   max_parallel_requests: 100 # max parallel requests for a user = 100
 ```

+## All settings
+
+```python
+{
+  "environment_variables": {},
+  "model_list": [
+    {
+      "model_name": "string",
+      "litellm_params": {},
+      "model_info": {
+        "id": "string",
+        "mode": "embedding",
+        "input_cost_per_token": 0,
+        "output_cost_per_token": 0,
+        "max_tokens": 2048,
+        "base_model": "gpt-4-1106-preview",
+        "additionalProp1": {}
+      }
+    }
+  ],
+  "litellm_settings": {}, # ALL (https://github.com/BerriAI/litellm/blob/main/litellm/__init__.py)
+  "general_settings": {
+    "completion_model": "string",
+    "key_management_system": "google_kms", # either google_kms or azure_kms
+    "master_key": "string",
+    "database_url": "string",
+    "database_type": "dynamo_db",
+    "database_args": {
+      "billing_mode": "PROVISIONED_THROUGHPUT",
+      "read_capacity_units": 0,
+      "write_capacity_units": 0,
+      "ssl_verify": true,
+      "region_name": "string",
+      "user_table_name": "LiteLLM_UserTable",
+      "key_table_name": "LiteLLM_VerificationToken",
+      "config_table_name": "LiteLLM_Config",
+      "spend_table_name": "LiteLLM_SpendLogs"
+    },
+    "otel": true,
+    "custom_auth": "string",
+    "max_parallel_requests": 0,
+    "infer_model_from_keys": true,
+    "background_health_checks": true,
+    "health_check_interval": 300,
+    "alerting": [
+      "string"
+    ],
+    "alerting_threshold": 0
+  }
+}
+```
docs/my-website/docs/proxy/debugging.md (new file, 34 lines)
@@ -0,0 +1,34 @@
+# Debugging
+
+2 levels of debugging supported.
+
+- debug (prints info logs)
+- detailed debug (prints debug logs)
+
+## `debug`
+
+**via cli**
+
+```bash
+$ litellm --debug
+```
+
+**via env**
+
+```python
+os.environ["LITELLM_LOG"] = "INFO"
+```
+
+## `detailed debug`
+
+**via cli**
+
+```bash
+$ litellm --detailed_debug
+```
+
+**via env**
+
+```python
+os.environ["LITELLM_LOG"] = "DEBUG"
+```
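The new debugging doc above exposes the log level through an environment variable. A minimal sketch of how that switch might be used in practice (an illustration, not taken from the diff itself; it assumes `LITELLM_LOG` is read when litellm initializes, so the variable is set before the import):

```python
import os

# Set the log level before importing litellm (assumption: the variable is
# read at import/initialization time, per the doc above).
os.environ["LITELLM_LOG"] = "DEBUG"  # or "INFO" for the lighter `--debug` level

import litellm  # noqa: E402  # imported after setting the env var on purpose
```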
@@ -112,7 +112,8 @@ Example Response:
 ```json
 {
     "status": "healthy",
-    "db": "connected"
+    "db": "connected",
+    "litellm_version":"1.19.2",
 }
 ```

@@ -121,7 +122,8 @@ Example Response:
 ```json
 {
     "status": "healthy",
-    "db": "Not connected"
+    "db": "Not connected",
+    "litellm_version":"1.19.2",
 }
 ```
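For reference, the `litellm_version` field added to the health responses above can be checked programmatically. A small sketch (the readiness path and port are assumptions based on these docs; adjust to your deployment):

```python
import requests

# Assumed local proxy on port 4000 exposing the readiness route shown in the docs.
resp = requests.get("http://localhost:4000/health/readiness")
info = resp.json()
print(info.get("status"))           # "healthy"
print(info.get("db"))               # "connected" or "Not connected"
print(info.get("litellm_version"))  # e.g. "1.19.2", the field added in this diff
```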
@@ -1,4 +1,6 @@
 import Image from '@theme/IdealImage';
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';

 # [BETA] Admin UI


@@ -26,22 +28,27 @@ general_settings:
   allow_user_auth: true
 ```

-## 2. Setup Google SSO - Use this to Authenticate Team Members to the UI
-- Create an Oauth 2.0 Client
-<Image img={require('../../img/google_oauth2.png')} />
-
-- Navigate to Google `Credenentials`
-- Create a new Oauth client ID
-- Set the `GOOGLE_CLIENT_ID` and `GOOGLE_CLIENT_SECRET` in your Proxy .env
-- Set Redirect URL on your Oauth 2.0 Client
-- Click on your Oauth 2.0 client on https://console.cloud.google.com/
-- Set a redirect url = `<your proxy base url>/google-callback`
-```
-https://litellm-production-7002.up.railway.app/google-callback
-```
-<Image img={require('../../img/google_redirect.png')} />
-## 3. Required env variables on your Proxy
+## 2. Setup SSO/Auth for UI
+
+<Tabs>
+<TabItem value="username" label="Quick Start - Username, Password">
+
+Set the following in your .env on the Proxy
+
+```shell
+UI_USERNAME=ishaan-litellm
+UI_PASSWORD=langchain
+```
+
+On accessing the LiteLLM UI, you will be prompted to enter your username, password
+
+</TabItem>
+
+<TabItem value="google" label="Google SSO">
+
+- Create a new Oauth 2.0 Client on https://console.cloud.google.com/
+
+**Required .env variables on your Proxy**
 ```shell
 PROXY_BASE_URL="<your deployed proxy endpoint>" example PROXY_BASE_URL=https://litellm-production-7002.up.railway.app/


@@ -50,6 +57,37 @@ GOOGLE_CLIENT_ID=
 GOOGLE_CLIENT_SECRET=
 ```

+- Set Redirect URL on your Oauth 2.0 Client on https://console.cloud.google.com/
+- Set a redirect url = `<your proxy base url>/sso/callback`
+```shell
+https://litellm-production-7002.up.railway.app/sso/callback
+```
+
+</TabItem>
+
+<TabItem value="msft" label="Microsoft SSO">
+
+- Create a new App Registration on https://portal.azure.com/
+- Create a client Secret for your App Registration
+
+**Required .env variables on your Proxy**
+```shell
+PROXY_BASE_URL="<your deployed proxy endpoint>" example PROXY_BASE_URL=https://litellm-production-7002.up.railway.app/
+
+MICROSOFT_CLIENT_ID="84583a4d-"
+MICROSOFT_CLIENT_SECRET="nbk8Q~"
+MICROSOFT_TENANT="5a39737
+```
+- Set Redirect URI on your App Registration on https://portal.azure.com/
+- Set a redirect url = `<your proxy base url>/sso/callback`
+```shell
+http://localhost:4000/sso/callback
+```
+
+</TabItem>
+
+</Tabs>
+
 ## 4. Use UI

 👉 Get Started here: https://litellm-dashboard.vercel.app/
@@ -278,6 +278,21 @@ Request Params:
 }
 ```

+## Default /key/generate params
+Use this, if you need to control the default `max_budget` or any `key/generate` param per key.
+
+When a `/key/generate` request does not specify `max_budget`, it will use the `max_budget` specified in `default_key_generate_params`
+
+Set `litellm_settings:default_key_generate_params`:
+```yaml
+litellm_settings:
+  default_key_generate_params:
+    max_budget: 1.5000
+    models: ["azure-gpt-3.5"]
+    duration: # blank means `null`
+    metadata: {"setting":"default"}
+    team_id: "core-infra"
+```
 ## Set Budgets - Per Key

 Set `max_budget` in (USD $) param in the `key/generate` request. By default the `max_budget` is set to `null` and is not checked for keys
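To see the defaulting behaviour described above in action, a hedged sketch of a key-generation call that omits `max_budget` (proxy URL and master key are placeholders; with the config shown, the generated key should pick up the 1.5 default):

```python
import requests

resp = requests.post(
    "http://localhost:4000/key/generate",         # placeholder proxy URL
    headers={"Authorization": "Bearer sk-1234"},   # placeholder master key
    json={},  # no max_budget given -> default_key_generate_params applies
)
print(resp.json())
```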
@@ -115,6 +115,7 @@ const sidebars = {
             "proxy/ui",
             "proxy/model_management",
             "proxy/health",
+            "proxy/debugging",
             {
                 "type": "category",
                 "label": "🔥 Load Balancing",
@@ -143,6 +143,7 @@ model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/mai
 suppress_debug_info = False
 dynamodb_table_name: Optional[str] = None
 s3_callback_params: Optional[Dict] = None
+default_key_generate_params: Optional[Dict] = None
 #### RELIABILITY ####
 request_timeout: Optional[float] = 6000
 num_retries: Optional[int] = None  # per model endpoint
@@ -702,6 +702,11 @@ def _embedding_func_single(
     encoding=None,
     logging_obj=None,
 ):
+    if isinstance(input, str) is False:
+        raise BedrockError(
+            message="Bedrock Embedding API input must be type str | List[str]",
+            status_code=400,
+        )
     # logic for parsing in - calling - parsing out model embedding calls
     ## FORMAT EMBEDDING INPUT ##
     provider = model.split(".")[0]

@@ -795,7 +800,8 @@ def embedding(
         aws_role_name=aws_role_name,
         aws_session_name=aws_session_name,
     )
-    if type(input) == str:
+    if isinstance(input, str):
+        ## Embedding Call
         embeddings = [
             _embedding_func_single(
                 model,

@@ -805,8 +811,8 @@ def embedding(
                 logging_obj=logging_obj,
             )
         ]
-    else:
-        ## Embedding Call
+    elif isinstance(input, list):
+        ## Embedding Call - assuming this is a List[str]
         embeddings = [
             _embedding_func_single(
                 model,

@@ -817,6 +823,12 @@ def embedding(
             )
             for i in input
         ]  # [TODO]: make these parallel calls
+    else:
+        # enters this branch if input = int, ex. input=2
+        raise BedrockError(
+            message="Bedrock Embedding API input must be type str | List[str]",
+            status_code=400,
+        )

     ## Populate OpenAI compliant dictionary
     embedding_response = []
@@ -145,8 +145,8 @@ def get_ollama_response(
     ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
         optional_params[k] = v

-    optional_params["stream"] = optional_params.get("stream", False)
-    data = {"model": model, "messages": messages, **optional_params}
+    stream = optional_params.pop("stream", False)
+    data = {"model": model, "messages": messages, "options": optional_params}
     ## LOGGING
     logging_obj.pre_call(
         input=None,

@@ -159,7 +159,7 @@ def get_ollama_response(
         },
     )
     if acompletion is True:
-        if optional_params.get("stream", False) == True:
+        if stream == True:
             response = ollama_async_streaming(
                 url=url,
                 data=data,

@@ -176,7 +176,7 @@ def get_ollama_response(
             logging_obj=logging_obj,
         )
         return response
-    elif optional_params.get("stream", False) == True:
+    elif stream == True:
         return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)

     response = requests.post(
@@ -221,6 +221,7 @@ class OpenAIChatCompletion(BaseLLM):
         headers: Optional[dict] = None,
         custom_prompt_dict: dict = {},
         client=None,
+        organization: Optional[str] = None,
     ):
         super().completion()
         exception_mapping_worked = False

@@ -254,6 +255,7 @@ class OpenAIChatCompletion(BaseLLM):
                         timeout=timeout,
                         client=client,
                         max_retries=max_retries,
+                        organization=organization,
                     )
                 else:
                     return self.acompletion(

@@ -266,6 +268,7 @@ class OpenAIChatCompletion(BaseLLM):
                         timeout=timeout,
                         client=client,
                         max_retries=max_retries,
+                        organization=organization,
                     )
             elif optional_params.get("stream", False):
                 return self.streaming(

@@ -278,6 +281,7 @@ class OpenAIChatCompletion(BaseLLM):
                     timeout=timeout,
                     client=client,
                     max_retries=max_retries,
+                    organization=organization,
                 )
             else:
                 if not isinstance(max_retries, int):

@@ -291,6 +295,7 @@ class OpenAIChatCompletion(BaseLLM):
                         http_client=litellm.client_session,
                         timeout=timeout,
                         max_retries=max_retries,
+                        organization=organization,
                     )
                 else:
                     openai_client = client

@@ -358,6 +363,7 @@ class OpenAIChatCompletion(BaseLLM):
         timeout: float,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
+        organization: Optional[str] = None,
         client=None,
         max_retries=None,
         logging_obj=None,

@@ -372,6 +378,7 @@ class OpenAIChatCompletion(BaseLLM):
                 http_client=litellm.aclient_session,
                 timeout=timeout,
                 max_retries=max_retries,
+                organization=organization,
             )
         else:
             openai_aclient = client

@@ -412,6 +419,7 @@ class OpenAIChatCompletion(BaseLLM):
         model: str,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
+        organization: Optional[str] = None,
         client=None,
         max_retries=None,
         headers=None,

@@ -423,6 +431,7 @@ class OpenAIChatCompletion(BaseLLM):
                 http_client=litellm.client_session,
                 timeout=timeout,
                 max_retries=max_retries,
+                organization=organization,
             )
         else:
             openai_client = client

@@ -454,6 +463,7 @@ class OpenAIChatCompletion(BaseLLM):
         model: str,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
+        organization: Optional[str] = None,
         client=None,
         max_retries=None,
         headers=None,

@@ -467,6 +477,7 @@ class OpenAIChatCompletion(BaseLLM):
                 http_client=litellm.aclient_session,
                 timeout=timeout,
                 max_retries=max_retries,
+                organization=organization,
             )
         else:
             openai_aclient = client

@@ -748,8 +759,11 @@ class OpenAIChatCompletion(BaseLLM):
         messages: Optional[list] = None,
         input: Optional[list] = None,
         prompt: Optional[str] = None,
+        organization: Optional[str] = None,
     ):
-        client = AsyncOpenAI(api_key=api_key, timeout=timeout)
+        client = AsyncOpenAI(
+            api_key=api_key, timeout=timeout, organization=organization
+        )
         if model is None and mode != "image_generation":
             raise Exception("model is not set")

@@ -104,8 +104,8 @@ def mistral_instruct_pt(messages):
         initial_prompt_value="<s>",
         role_dict={
             "system": {
-                "pre_message": "[INST] <<SYS>>\n",
-                "post_message": "<</SYS>> [/INST]\n",
+                "pre_message": "[INST] \n",
+                "post_message": " [/INST]\n",
             },
             "user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"},
             "assistant": {"pre_message": " ", "post_message": " "},

@@ -376,6 +376,7 @@ def anthropic_pt(
     You can "put words in Claude's mouth" by ending with an assistant message.
     See: https://docs.anthropic.com/claude/docs/put-words-in-claudes-mouth
     """
+
 class AnthropicConstants(Enum):
     HUMAN_PROMPT = "\n\nHuman: "
     AI_PROMPT = "\n\nAssistant: "

@@ -403,27 +404,30 @@ def _load_image_from_url(image_url):
     try:
         from PIL import Image
     except:
-        raise Exception("gemini image conversion failed please run `pip install Pillow`")
+        raise Exception(
+            "gemini image conversion failed please run `pip install Pillow`"
+        )
     from io import BytesIO

     try:
         # Send a GET request to the image URL
         response = requests.get(image_url)
         response.raise_for_status()  # Raise an exception for HTTP errors

         # Check the response's content type to ensure it is an image
-        content_type = response.headers.get('content-type')
-        if not content_type or 'image' not in content_type:
-            raise ValueError(f"URL does not point to a valid image (content-type: {content_type})")
+        content_type = response.headers.get("content-type")
+        if not content_type or "image" not in content_type:
+            raise ValueError(
+                f"URL does not point to a valid image (content-type: {content_type})"
+            )

         # Load the image from the response content
         return Image.open(BytesIO(response.content))

     except requests.RequestException as e:
-        print(f"Request failed: {e}")
-    except UnidentifiedImageError:
-        print("Cannot identify image file (it may not be a supported image format or might be corrupted).")
-    except ValueError as e:
-        print(e)
+        raise Exception(f"Request failed: {e}")
+    except Exception as e:
+        raise e


 def _gemini_vision_convert_messages(messages: list):

@@ -441,10 +445,11 @@ def _gemini_vision_convert_messages(messages: list):
     try:
         from PIL import Image
     except:
-        raise Exception("gemini image conversion failed please run `pip install Pillow`")
+        raise Exception(
+            "gemini image conversion failed please run `pip install Pillow`"
+        )

     try:
         # given messages for gpt-4 vision, convert them for gemini
         # https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb
         prompt = ""

@@ -593,7 +598,7 @@ def prompt_factory(
     if custom_llm_provider == "ollama":
         return ollama_pt(model=model, messages=messages)
     elif custom_llm_provider == "anthropic":
-        if any(_ in model for _ in ["claude-2.1","claude-v2:1"]):
+        if any(_ in model for _ in ["claude-2.1", "claude-v2:1"]):
             return claude_2_1_pt(messages=messages)
         else:
             return anthropic_pt(messages=messages)
@@ -237,8 +237,11 @@ def completion(
             GenerationConfig,
         )
         from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
+        import google.auth

-        vertexai.init(project=vertex_project, location=vertex_location)
+        ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
+        creds, _ = google.auth.default(quota_project_id=vertex_project)
+        vertexai.init(project=vertex_project, location=vertex_location, credentials=creds)

         ## Load Config
         config = litellm.VertexAIConfig.get_config()
@@ -450,6 +450,7 @@ def completion(
     num_retries = kwargs.get("num_retries", None)  ## deprecated
     max_retries = kwargs.get("max_retries", None)
     context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
+    organization = kwargs.get("organization", None)
     ### CUSTOM MODEL COST ###
     input_cost_per_token = kwargs.get("input_cost_per_token", None)
     output_cost_per_token = kwargs.get("output_cost_per_token", None)
@@ -591,26 +592,37 @@ def completion(

     ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
     if input_cost_per_token is not None and output_cost_per_token is not None:
+        print_verbose(f"Registering model={model} in model cost map")
         litellm.register_model(
             {
+                f"{custom_llm_provider}/{model}": {
+                    "input_cost_per_token": input_cost_per_token,
+                    "output_cost_per_token": output_cost_per_token,
+                    "litellm_provider": custom_llm_provider,
+                },
                 model: {
                     "input_cost_per_token": input_cost_per_token,
                     "output_cost_per_token": output_cost_per_token,
                     "litellm_provider": custom_llm_provider,
-                }
+                },
             }
         )
-    if (
+    elif (
         input_cost_per_second is not None
     ):  # time based pricing just needs cost in place
         output_cost_per_second = output_cost_per_second or 0.0
         litellm.register_model(
             {
+                f"{custom_llm_provider}/{model}": {
+                    "input_cost_per_second": input_cost_per_second,
+                    "output_cost_per_second": output_cost_per_second,
+                    "litellm_provider": custom_llm_provider,
+                },
                 model: {
                     "input_cost_per_second": input_cost_per_second,
                     "output_cost_per_second": output_cost_per_second,
                     "litellm_provider": custom_llm_provider,
-                }
+                },
             }
         )
     ### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ###
|
||||||
or "https://api.openai.com/v1"
|
or "https://api.openai.com/v1"
|
||||||
)
|
)
|
||||||
openai.organization = (
|
openai.organization = (
|
||||||
litellm.organization
|
organization
|
||||||
|
or litellm.organization
|
||||||
or get_secret("OPENAI_ORGANIZATION")
|
or get_secret("OPENAI_ORGANIZATION")
|
||||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||||
)
|
)
|
||||||
|
@@ -827,6 +840,7 @@ def completion(
                 timeout=timeout,
                 custom_prompt_dict=custom_prompt_dict,
                 client=client,  # pass AsyncOpenAI, OpenAI client
+                organization=organization,
             )
         except Exception as e:
             ## LOGGING - log the original exception returned
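The hunks above thread a new `organization` kwarg from `completion()` through to the OpenAI client. A minimal sketch of a caller exercising it (the org ID is a placeholder, and this assumes the kwarg is forwarded exactly as the diff shows):

```python
import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    organization="org-ikDc4ex8NB",  # placeholder org ID, same one used in the config example earlier
)
print(response)
```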
@@ -3224,6 +3238,7 @@ async def ahealth_check(
             or custom_llm_provider == "text-completion-openai"
         ):
             api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY")
+            organization = model_params.get("organization")

             timeout = (
                 model_params.get("timeout")

@@ -3241,6 +3256,7 @@ async def ahealth_check(
                 mode=mode,
                 prompt=prompt,
                 input=input,
+                organization=organization,
             )
         else:
             if mode == "embedding":
@@ -3265,6 +3281,7 @@
 ## Set verbose to true -> ```litellm.set_verbose = True```
 def print_verbose(print_statement):
     try:
+        verbose_logger.debug(print_statement)
         if litellm.set_verbose:
             print(print_statement)  # noqa
     except:
@@ -233,10 +233,6 @@ class DynamoDBWrapper(CustomDB):
             table = client.table(self.database_arguments.config_table_name)
             key_name = "param_name"

-        if key_name == "token" and key.startswith("sk-"):
-            # ensure it's hashed
-            key = hash_token(token=key)
-
         response = await table.get_item({key_name: key})

         new_response: Any = None
@@ -17,7 +17,12 @@ class MaxParallelRequestsHandler(CustomLogger):
         pass

     def print_verbose(self, print_statement):
+        try:
             verbose_proxy_logger.debug(print_statement)
+            if litellm.set_verbose:
+                print(print_statement)  # noqa
+        except:
+            pass

     async def async_pre_call_hook(
         self,
@@ -157,6 +157,12 @@ def is_port_in_use(port):
     type=int,
     help="Number of requests to hit async endpoint with",
 )
+@click.option(
+    "--run_gunicorn",
+    default=False,
+    is_flag=True,
+    help="Starts proxy via gunicorn, instead of uvicorn (better for managing multiple workers)",
+)
 @click.option("--local", is_flag=True, default=False, help="for local debugging")
 def run_server(
     host,

@@ -186,6 +192,7 @@ def run_server(
     use_queue,
     health,
     version,
+    run_gunicorn,
 ):
     global feature_telemetry
     args = locals()

@@ -439,9 +446,9 @@ def run_server(
             port = random.randint(1024, 49152)
         from litellm.proxy.proxy_server import app

-        if os.name == "nt":
+        if run_gunicorn == False:
             uvicorn.run(app, host=host, port=port)  # run uvicorn
-        else:
+        elif run_gunicorn == True:
             import gunicorn.app.base

             # Gunicorn Application Class
@@ -69,11 +69,16 @@ litellm_settings:
   success_callback: ['langfuse']
   max_budget: 10 # global budget for proxy
   budget_duration: 30d # global budget duration, will reset after 30d
+  default_key_generate_params:
+    max_budget: 1.5000
+    models: ["azure-gpt-3.5"]
+    duration: None
   # cache: True
   # setting callback class
   # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]

 general_settings:
+  allow_user_auth: True
   master_key: sk-1234
   alerting: ["slack"]
   alerting_threshold: 10 # sends alerts if requests hang for 2 seconds
@@ -76,6 +76,7 @@ from litellm.proxy.utils import (
     get_logging_payload,
     reset_budget,
     hash_token,
+    html_form,
 )
 from litellm.proxy.secret_managers.google_kms import load_google_kms
 import pydantic

@@ -94,6 +95,7 @@ from fastapi import (
     BackgroundTasks,
     Header,
     Response,
+    Form,
 )
 from fastapi.routing import APIRouter
 from fastapi.security import OAuth2PasswordBearer
@@ -245,8 +247,6 @@ async def user_api_key_auth(
             response = await user_custom_auth(request=request, api_key=api_key)
             return UserAPIKeyAuth.model_validate(response)
         ### LITELLM-DEFINED AUTH FUNCTION ###
-        if isinstance(api_key, str):
-            assert api_key.startswith("sk-")  # prevent token hashes from being used
         if master_key is None:
             if isinstance(api_key, str):
                 return UserAPIKeyAuth(api_key=api_key)
@@ -283,6 +283,10 @@ async def user_api_key_auth(
         if is_master_key_valid:
             return UserAPIKeyAuth(api_key=master_key)

+        if isinstance(
+            api_key, str
+        ):  # if generated token, make sure it starts with sk-.
+            assert api_key.startswith("sk-")  # prevent token hashes from being used
         if route.startswith("/config/") and not is_master_key_valid:
             raise Exception(f"Only admin can modify config")

@@ -292,6 +296,7 @@ async def user_api_key_auth(
             raise Exception("No connected db.")

         ## check for cache hit (In-Memory Cache)
+        original_api_key = api_key  # (Patch: For DynamoDB Backwards Compatibility)
         if api_key.startswith("sk-"):
             api_key = hash_token(token=api_key)
         valid_token = user_api_key_cache.get_cache(key=api_key)
@@ -304,10 +309,15 @@ async def user_api_key_auth(
                 )

             elif custom_db_client is not None:
+                try:
                     valid_token = await custom_db_client.get_data(
                         key=api_key, table_name="key"
                     )
+                except:
+                    # (Patch: For DynamoDB Backwards Compatibility)
+                    valid_token = await custom_db_client.get_data(
+                        key=original_api_key, table_name="key"
+                    )
             verbose_proxy_logger.debug(f"Token from db: {valid_token}")
         elif valid_token is not None:
             verbose_proxy_logger.debug(f"API Key Cache Hit!")
@@ -1117,6 +1127,9 @@ class ProxyConfig:
                     # see usage here: https://docs.litellm.ai/docs/proxy/caching
                     pass
                 else:
+                    verbose_proxy_logger.debug(
+                        f"{blue_color_code} setting litellm.{key}={value}{reset_color_code}"
+                    )
                     setattr(litellm, key, value)

         ## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
@@ -2385,6 +2398,26 @@ async def generate_key_fn(
             raise HTTPException(
                 status_code=status.HTTP_403_FORBIDDEN, detail=message
             )
+    # check if user set default key/generate params on config.yaml
+    if litellm.default_key_generate_params is not None:
+        for elem in data:
+            key, value = elem
+            if value is None and key in [
+                "max_budget",
+                "user_id",
+                "team_id",
+                "max_parallel_requests",
+                "tpm_limit",
+                "rpm_limit",
+                "budget_duration",
+            ]:
+                setattr(
+                    data, key, litellm.default_key_generate_params.get(key, None)
+                )
+            elif key == "models" and value == []:
+                setattr(data, key, litellm.default_key_generate_params.get(key, []))
+            elif key == "metadata" and value == {}:
+                setattr(data, key, litellm.default_key_generate_params.get(key, {}))

     data_json = data.json()  # type: ignore

@@ -2854,7 +2887,7 @@ async def user_auth(request: Request):
     return "Email sent!"


-@app.get("/google-login/key/generate", tags=["experimental"])
+@app.get("/sso/key/generate", tags=["experimental"])
 async def google_login(request: Request):
     """
     Create Proxy API Keys using Google Workspace SSO. Requires setting GOOGLE_REDIRECT_URI in .env
@@ -2863,102 +2896,108 @@ async def google_login(request: Request):
     Example:

     """
-    GOOGLE_REDIRECT_URI = os.getenv("PROXY_BASE_URL")
-    if GOOGLE_REDIRECT_URI is None:
+    microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
+    google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
+    redirect_url = os.getenv("PROXY_BASE_URL", None)
+    if redirect_url is None:
         raise ProxyException(
             message="PROXY_BASE_URL not set. Set it in .env file",
             type="auth_error",
             param="PROXY_BASE_URL",
             code=status.HTTP_500_INTERNAL_SERVER_ERROR,
         )
-    if GOOGLE_REDIRECT_URI.endswith("/"):
-        GOOGLE_REDIRECT_URI += "google-callback"
+    if redirect_url.endswith("/"):
+        redirect_url += "sso/callback"
     else:
-        GOOGLE_REDIRECT_URI += "/google-callback"
-
-    GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID")
-    if GOOGLE_CLIENT_ID is None:
-        GOOGLE_CLIENT_ID = (
-            "246483686424-clje5sggkjma26ilktj6qssakqhoon0m.apps.googleusercontent.com"
-        )
-
-    verbose_proxy_logger.info(
-        f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {GOOGLE_REDIRECT_URI}\nGOOGLE_CLIENT_ID: {GOOGLE_CLIENT_ID}"
-    )
-    google_auth_url = f"https://accounts.google.com/o/oauth2/auth?client_id={GOOGLE_CLIENT_ID}&redirect_uri={GOOGLE_REDIRECT_URI}&response_type=code&scope=openid%20profile%20email"
-    return RedirectResponse(url=google_auth_url)
-
-
-@app.get("/google-callback", tags=["experimental"], response_model=GenerateKeyResponse)
-async def google_callback(code: str, request: Request):
-    import httpx
-
-    GOOGLE_REDIRECT_URI = os.getenv("PROXY_BASE_URL")
-    if GOOGLE_REDIRECT_URI is None:
-        raise ProxyException(
-            message="PROXY_BASE_URL not set. Set it in .env file",
-            type="auth_error",
-            param="PROXY_BASE_URL",
-            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-        )
-    # Add "/google-callback"" to your callback URL
-    if GOOGLE_REDIRECT_URI.endswith("/"):
-        GOOGLE_REDIRECT_URI += "google-callback"
-    else:
-        GOOGLE_REDIRECT_URI += "/google-callback"
-
-    GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID")
-    if GOOGLE_CLIENT_ID is None:
-        GOOGLE_CLIENT_ID = (
-            "246483686424-clje5sggkjma26ilktj6qssakqhoon0m.apps.googleusercontent.com"
-        )
-
-    GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET")
-    if GOOGLE_CLIENT_SECRET is None:
-        GOOGLE_CLIENT_SECRET = "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"
-
-    verbose_proxy_logger.info(
-        f"/google-callback\n GOOGLE_REDIRECT_URI: {GOOGLE_REDIRECT_URI}\n GOOGLE_CLIENT_ID: {GOOGLE_CLIENT_ID}"
-    )
-    # Exchange code for access token
-    async with httpx.AsyncClient() as client:
-        token_url = f"https://oauth2.googleapis.com/token"
-        data = {
-            "code": code,
-            "client_id": GOOGLE_CLIENT_ID,
-            "client_secret": GOOGLE_CLIENT_SECRET,
-            "redirect_uri": GOOGLE_REDIRECT_URI,
-            "grant_type": "authorization_code",
-        }
-        response = await client.post(token_url, data=data)
-
-    # Process the response, extract user info, etc.
-    if response.status_code == 200:
-        access_token = response.json()["access_token"]
-
-        # Fetch user info using the access token
-        async with httpx.AsyncClient() as client:
-            user_info_url = "https://www.googleapis.com/oauth2/v1/userinfo"
-            headers = {"Authorization": f"Bearer {access_token}"}
-            user_info_response = await client.get(user_info_url, headers=headers)
-
-        # Process user info response
-        if user_info_response.status_code == 200:
-            user_info = user_info_response.json()
-            user_email = user_info.get("email")
-            user_name = user_info.get("name")
-
-            # we can use user_email on litellm proxy now
-            # TODO: Handle user info as needed, for example, store it in a database, authenticate the user, etc.
+        redirect_url += "/sso/callback"
+    # Google SSO Auth
+    if google_client_id is not None:
+        from fastapi_sso.sso.google import GoogleSSO
+
+        google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
+        if google_client_secret is None:
+            raise ProxyException(
+                message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
+                type="auth_error",
+                param="GOOGLE_CLIENT_SECRET",
+                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            )
+        google_sso = GoogleSSO(
+            client_id=google_client_id,
+            client_secret=google_client_secret,
+            redirect_uri=redirect_url,
+        )
+        verbose_proxy_logger.info(
+            f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
+        )
+        with google_sso:
+            return await google_sso.get_login_redirect()
+
+    # Microsoft SSO Auth
+    elif microsoft_client_id is not None:
+        from fastapi_sso.sso.microsoft import MicrosoftSSO
+
+        microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
+        microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
+        if microsoft_client_secret is None:
+            raise ProxyException(
+                message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
+                type="auth_error",
+                param="MICROSOFT_CLIENT_SECRET",
+                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            )
+        microsoft_sso = MicrosoftSSO(
+            client_id=microsoft_client_id,
+            client_secret=microsoft_client_secret,
+            tenant=microsoft_tenant,
+            redirect_uri=redirect_url,
+            allow_insecure_http=True,
+        )
+        with microsoft_sso:
+            return await microsoft_sso.get_login_redirect()
+    else:
+        # No Google, Microsoft SSO
+        # Use UI Credentials set in .env
+        from fastapi.responses import HTMLResponse
+
+        return HTMLResponse(content=html_form, status_code=200)
+
+
+@router.post(
+    "/login", include_in_schema=False
+)  # hidden since this is a helper for UI sso login
+async def login(request: Request):
+    try:
+        import multipart
+    except ImportError:
+        subprocess.run(["pip", "install", "python-multipart"])
+
+    form = await request.form()
+    username = str(form.get("username"))
+    password = form.get("password")
+    ui_username = os.getenv("UI_USERNAME")
+    ui_password = os.getenv("UI_PASSWORD")
+
+    if username == ui_username and password == ui_password:
+        user_id = username
         response = await generate_key_helper_fn(
-            **{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_email, "team_id": "litellm-dashboard"}  # type: ignore
+            **{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard"}  # type: ignore
         )

         key = response["token"]  # type: ignore
         user_id = response["user_id"]  # type: ignore
         litellm_dashboard_ui = "https://litellm-dashboard.vercel.app/"

+        # if user set LITELLM_UI_LINK in .env, use that
+        litellm_ui_link_in_env = os.getenv("LITELLM_UI_LINK", None)
+        if litellm_ui_link_in_env is not None:
+            litellm_dashboard_ui = litellm_ui_link_in_env
+
         litellm_dashboard_ui += (
             "?userID="
             + user_id
@@ -2968,16 +3007,108 @@ async def google_callback(code: str, request: Request):
             + os.getenv("PROXY_BASE_URL")
         )
         return RedirectResponse(url=litellm_dashboard_ui)

     else:
-        # Handle user info retrieval error
-        raise HTTPException(
-            status_code=user_info_response.status_code,
-            detail=user_info_response.text,
+        raise ProxyException(
+            message=f"Invalid credentials used to access UI. Passed in username: {username}, passed in password: {password}.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
+            type="auth_error",
+            param="invalid_credentials",
+            code=status.HTTP_401_UNAUTHORIZED,
         )
+
+
+@app.get("/sso/callback", tags=["experimental"])
+async def auth_callback(request: Request):
+    """Verify login"""
+    microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
+    google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
+
+    redirect_url = os.getenv("PROXY_BASE_URL", None)
+    if redirect_url is None:
+        raise ProxyException(
+            message="PROXY_BASE_URL not set. Set it in .env file",
+            type="auth_error",
+            param="PROXY_BASE_URL",
+            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+        )
+    if redirect_url.endswith("/"):
+        redirect_url += "sso/callback"
     else:
-        # Handle the error from the token exchange
-        raise HTTPException(status_code=response.status_code, detail=response.text)
+        redirect_url += "/sso/callback"
+
+    if google_client_id is not None:
+        from fastapi_sso.sso.google import GoogleSSO
+
+        google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
+        if google_client_secret is None:
+            raise ProxyException(
+                message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
+                type="auth_error",
+                param="GOOGLE_CLIENT_SECRET",
+                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            )
+        google_sso = GoogleSSO(
+            client_id=google_client_id,
+            redirect_uri=redirect_url,
+            client_secret=google_client_secret,
+        )
+        result = await google_sso.verify_and_process(request)
+
+    elif microsoft_client_id is not None:
+        from fastapi_sso.sso.microsoft import MicrosoftSSO
+
+        microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
+        microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
+        if microsoft_client_secret is None:
+            raise ProxyException(
+                message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
+                type="auth_error",
+                param="MICROSOFT_CLIENT_SECRET",
+                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            )
+        if microsoft_tenant is None:
+            raise ProxyException(
+                message="MICROSOFT_TENANT not set. Set it in .env file",
+                type="auth_error",
+                param="MICROSOFT_TENANT",
+                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            )
+
+        microsoft_sso = MicrosoftSSO(
+            client_id=microsoft_client_id,
+            client_secret=microsoft_client_secret,
+            tenant=microsoft_tenant,
+            redirect_uri=redirect_url,
+            allow_insecure_http=True,
+        )
+        result = await microsoft_sso.verify_and_process(request)
+
+    # User is Authe'd in - generate key for the UI to access Proxy
+    user_id = getattr(result, "email", None)
+    if user_id is None:
+        user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "")
+
+    response = await generate_key_helper_fn(
+        **{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard"}  # type: ignore
+    )
+
+    key = response["token"]  # type: ignore
+    user_id = response["user_id"]  # type: ignore
+    litellm_dashboard_ui = "https://litellm-dashboard.vercel.app/"
+
+    # if user set LITELLM_UI_LINK in .env, use that
+    litellm_ui_link_in_env = os.getenv("LITELLM_UI_LINK", None)
+    if litellm_ui_link_in_env is not None:
+        litellm_dashboard_ui = litellm_ui_link_in_env
+
+    litellm_dashboard_ui += (
+        "?userID="
+        + user_id
+        + "&accessToken="
+        + key
+        + "&proxyBaseUrl="
+        + os.getenv("PROXY_BASE_URL")
+    )
+    return RedirectResponse(url=litellm_dashboard_ui)


 @router.get(
@@ -3589,6 +3720,8 @@ async def health_readiness():
     cache_type = None
     if litellm.cache is not None:
         cache_type = litellm.cache.type
+    from litellm._version import version
+
     if prisma_client is not None:  # if db passed in, check if it's connected
         if prisma_client.db.is_connected() == True:
             response_object = {"db": "connected"}

@@ -3597,6 +3730,7 @@ async def health_readiness():
                 "status": "healthy",
                 "db": "connected",
                 "cache": cache_type,
+                "litellm_version": version,
                 "success_callbacks": litellm.success_callback,
             }
     else:

@@ -3604,6 +3738,7 @@ async def health_readiness():
             "status": "healthy",
             "db": "Not connected",
             "cache": cache_type,
+            "litellm_version": version,
             "success_callbacks": litellm.success_callback,
         }
     raise HTTPException(status_code=503, detail="Service Unhealthy")
@@ -21,6 +21,7 @@ from datetime import datetime, timedelta


 def print_verbose(print_statement):
+    verbose_proxy_logger.debug(print_statement)
     if litellm.set_verbose:
         print(f"LiteLLM Proxy: {print_statement}")  # noqa


@@ -96,6 +97,7 @@ class ProxyLogging:
         2. /embeddings
         3. /image/generation
         """
+        print_verbose(f"Inside Proxy Logging Pre-call hook!")
         ### ALERTING ###
         asyncio.create_task(self.response_taking_too_long(request_data=data))


@@ -1035,7 +1037,7 @@ async def send_email(sender_name, sender_email, receiver_email, subject, html):
         print_verbose(f"SMTP Connection Init")
         # Establish a secure connection with the SMTP server
         with smtplib.SMTP(smtp_host, smtp_port) as server:
-            if os.getenv("SMTP_TLS", 'True') != "False":
+            if os.getenv("SMTP_TLS", "True") != "False":
                 server.starttls()

             # Login to your email account
@ -1206,3 +1208,67 @@ async def reset_budget(prisma_client: PrismaClient):
|
||||||
await prisma_client.update_data(
|
await prisma_client.update_data(
|
||||||
query_type="update_many", data_list=users_to_reset, table_name="user"
|
query_type="update_many", data_list=users_to_reset, table_name="user"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# LiteLLM Admin UI - Non SSO Login
|
||||||
|
html_form = """
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>LiteLLM Login</title>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: Arial, sans-serif;
|
||||||
|
background-color: #f4f4f4;
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
align-items: center;
|
||||||
|
height: 100vh;
|
||||||
|
}
|
||||||
|
|
||||||
|
form {
|
||||||
|
background-color: #fff;
|
||||||
|
padding: 20px;
|
||||||
|
border-radius: 8px;
|
||||||
|
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
label {
|
||||||
|
display: block;
|
||||||
|
margin-bottom: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
input {
|
||||||
|
width: 100%;
|
||||||
|
padding: 8px;
|
||||||
|
margin-bottom: 16px;
|
||||||
|
box-sizing: border-box;
|
||||||
|
border: 1px solid #ccc;
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
input[type="submit"] {
|
||||||
|
background-color: #4caf50;
|
||||||
|
color: #fff;
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
input[type="submit"]:hover {
|
||||||
|
background-color: #45a049;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<form action="/login" method="post">
|
||||||
|
<h2>LiteLLM Login</h2>
|
||||||
|
<label for="username">Username:</label>
|
||||||
|
<input type="text" id="username" name="username" required>
|
||||||
|
<label for="password">Password:</label>
|
||||||
|
<input type="password" id="password" name="password" required>
|
||||||
|
<input type="submit" value="Submit">
|
||||||
|
</form>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
|
@@ -1411,6 +1411,12 @@ class Router:
                 max_retries = litellm.get_secret(max_retries_env_name)
                 litellm_params["max_retries"] = max_retries

+            organization = litellm_params.get("organization", None)
+            if isinstance(organization, str) and organization.startswith("os.environ/"):
+                organization_env_name = organization.replace("os.environ/", "")
+                organization = litellm.get_secret(organization_env_name)
+                litellm_params["organization"] = organization
+
             if "azure" in model_name:
                 if api_base is None:
                     raise ValueError(

@@ -1610,6 +1616,7 @@ class Router:
                     base_url=api_base,
                     timeout=timeout,
                     max_retries=max_retries,
+                    organization=organization,
                     http_client=httpx.AsyncClient(
                         transport=AsyncCustomHTTPTransport(),
                         limits=httpx.Limits(

@@ -1630,6 +1637,7 @@ class Router:
                     base_url=api_base,
                     timeout=timeout,
                     max_retries=max_retries,
+                    organization=organization,
                     http_client=httpx.Client(
                         transport=CustomHTTPTransport(),
                         limits=httpx.Limits(

@@ -1651,6 +1659,7 @@ class Router:
                     base_url=api_base,
                     timeout=stream_timeout,
                     max_retries=max_retries,
+                    organization=organization,
                     http_client=httpx.AsyncClient(
                         transport=AsyncCustomHTTPTransport(),
                         limits=httpx.Limits(

@@ -1672,6 +1681,7 @@ class Router:
                     base_url=api_base,
                     timeout=stream_timeout,
                     max_retries=max_retries,
+                    organization=organization,
                     http_client=httpx.Client(
                         transport=CustomHTTPTransport(),
                         limits=httpx.Limits(
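For context: with the first hunk above, `organization` is resolved like the Router's other secrets, so it can be passed either as a literal value or as an `os.environ/` reference; the later hunks forward it to each OpenAI/Azure client. A rough usage sketch (model names and the env var are illustrative, not from the diff):

```python
# Illustrative sketch: passing `organization` through the Router's model_list.
# The os.environ/ form is resolved by the new code via litellm.get_secret.
import os
from litellm import Router

os.environ["OPENAI_ORGANIZATION"] = "org-ikDc4ex8NB"  # example value only

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "organization": "os.environ/OPENAI_ORGANIZATION",
            },
        }
    ]
)
```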
@@ -70,18 +70,16 @@ def test_completion_with_empty_model():


 def test_completion_invalid_param_cohere():
     try:
-        response = completion(model="command-nightly", messages=messages, top_p=1)
-        print(f"response: {response}")
+        litellm.set_verbose = True
+        response = completion(model="command-nightly", messages=messages, seed=12)
+        pytest.fail(f"This should have failed cohere does not support `seed` parameter")
     except Exception as e:
-        if "Unsupported parameters passed: top_p" in str(e):
+        if " cohere does not support parameters: {'seed': 12}" in str(e):
             pass
         else:
             pytest.fail(f"An error occurred {e}")


+# test_completion_invalid_param_cohere()


 def test_completion_function_call_cohere():
     try:
         response = completion(

@@ -515,7 +515,7 @@ def hf_test_completion_tgi():
 # hf_test_error_logs()


-# def test_completion_cohere(): # commenting for now as the cohere endpoint is being flaky
+# def test_completion_cohere(): # commenting out,for now as the cohere endpoint is being flaky
 #     try:
 #         litellm.CohereConfig(max_tokens=10, stop_sequences=["a"])
 #         response = completion(

@@ -569,6 +569,22 @@ def test_completion_openai():
 # test_completion_openai()


+def test_completion_openai_organization():
+    try:
+        litellm.set_verbose = True
+        try:
+            response = completion(
+                model="gpt-3.5-turbo", messages=messages, organization="org-ikDc4ex8NB"
+            )
+            pytest.fail("Request should have failed - This organization does not exist")
+        except Exception as e:
+            assert "No such organization: org-ikDc4ex8NB" in str(e)
+
+    except Exception as e:
+        print(e)
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_text_openai():
     try:
         # litellm.set_verbose = True
@@ -302,6 +302,25 @@ def test_bedrock_embedding_cohere():
 # test_bedrock_embedding_cohere()


+def test_demo_tokens_as_input_to_embeddings_fails_for_titan():
+    litellm.set_verbose = True
+
+    with pytest.raises(
+        litellm.BadRequestError,
+        match="BedrockException - Bedrock Embedding API input must be type str | List[str]",
+    ):
+        litellm.embedding(model="amazon.titan-embed-text-v1", input=[[1]])
+
+    with pytest.raises(
+        litellm.BadRequestError,
+        match="BedrockException - Bedrock Embedding API input must be type str | List[str]",
+    ):
+        litellm.embedding(
+            model="amazon.titan-embed-text-v1",
+            input=[1],
+        )
+
+
 # comment out hf tests - since hf endpoints are unstable
 def test_hf_embedding():
     try:
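For context: the new test pins down that the Bedrock embedding path only accepts `str | List[str]` and raises `BadRequestError` for token-id inputs. A sketch of the accepted shapes (assumes AWS credentials/region for Bedrock are configured; not part of the diff):

```python
# Accepted input shapes for the Bedrock embedding path, per the test above.
# Assumes AWS credentials/region for Bedrock are set in the environment.
import litellm

litellm.embedding(model="amazon.titan-embed-text-v1", input="hello world")
litellm.embedding(model="amazon.titan-embed-text-v1", input=["hello", "world"])
```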
@@ -472,3 +472,32 @@ def test_call_with_key_over_budget_stream(custom_db_client):
         error_detail = e.message
         assert "Authentication Error, ExceededTokenBudget:" in error_detail
         print(vars(e))
+
+
+def test_dynamo_db_migration(custom_db_client):
+    # Tests the temporary patch we have in place
+    setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(litellm.proxy.proxy_server, "user_custom_auth", None)
+    try:
+
+        async def test():
+            bearer_token = (
+                "Bearer " + "sk-elJDL2pOEjcAuC7zD4psAg"
+            )  # this works with ishaan's db, it's a never expiring key
+
+            request = Request(scope={"type": "http"})
+            request._url = URL(url="/chat/completions")
+
+            async def return_body():
+                return b'{"model": "azure-models"}'
+
+            request.body = return_body
+
+            # use generated key to auth in
+            result = await user_api_key_auth(request=request, api_key=bearer_token)
+            print("result from user auth with new key", result)
+
+        asyncio.run(test())
+    except Exception as e:
+        pytest.fail(f"An exception occurred - {str(e)}")
@@ -1188,3 +1188,27 @@ async def test_key_name_set(prisma_client):
     except Exception as e:
         print("Got Exception", e)
         pytest.fail(f"Got exception {e}")
+
+
+@pytest.mark.asyncio()
+async def test_default_key_params(prisma_client):
+    """
+    - create key
+    - get key info
+    - assert key_name is not null
+    """
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True})
+    litellm.default_key_generate_params = {"max_budget": 0.000122}
+    await litellm.proxy.proxy_server.prisma_client.connect()
+    try:
+        request = GenerateKeyRequest()
+        key = await generate_key_fn(request)
+        generated_key = key.key
+        result = await info_key_fn(key=generated_key)
+        print("result from info_key_fn", result)
+        assert result["info"]["max_budget"] == 0.000122
+    except Exception as e:
+        print("Got Exception", e)
+        pytest.fail(f"Got exception {e}")
@@ -456,6 +456,7 @@ async def test_streaming_router_call():

 @pytest.mark.asyncio
 async def test_streaming_router_tpm_limit():
+    litellm.set_verbose = True
     model_list = [
         {
             "model_name": "azure-model",

@@ -520,7 +521,7 @@ async def test_streaming_router_tpm_limit():
         )
         async for chunk in response:
             continue
-        await asyncio.sleep(1)  # success is done in a separate thread
+        await asyncio.sleep(5)  # success is done in a separate thread

         try:
             await parallel_request_handler.async_pre_call_hook(
@@ -387,3 +387,56 @@ def test_router_init_gpt_4_vision_enhancements():
         print("passed")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+def test_openai_with_organization():
+    try:
+        print("Testing OpenAI with organization")
+        model_list = [
+            {
+                "model_name": "openai-bad-org",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "organization": "org-ikDc4ex8NB",
+                },
+            },
+            {
+                "model_name": "openai-good-org",
+                "litellm_params": {"model": "gpt-3.5-turbo"},
+            },
+        ]
+
+        router = Router(model_list=model_list)
+
+        print(router.model_list)
+        print(router.model_list[0])
+
+        openai_client = router._get_client(
+            deployment=router.model_list[0],
+            kwargs={"input": ["hello"], "model": "openai-bad-org"},
+        )
+        print(vars(openai_client))
+
+        assert openai_client.organization == "org-ikDc4ex8NB"
+
+        # bad org raises error
+
+        try:
+            response = router.completion(
+                model="openai-bad-org",
+                messages=[{"role": "user", "content": "this is a test"}],
+            )
+            pytest.fail("Request should have failed - This organization does not exist")
+        except Exception as e:
+            print("Got exception: " + str(e))
+            assert "No such organization: org-ikDc4ex8NB" in str(e)
+
+        # good org works
+        response = router.completion(
+            model="openai-good-org",
+            messages=[{"role": "user", "content": "this is a test"}],
+            max_tokens=5,
+        )
+
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
@@ -2929,32 +2929,10 @@ def cost_per_token(
         model_with_provider_and_region in model_cost_ref
     ):  # use region based pricing, if it's available
         model_with_provider = model_with_provider_and_region
+    if model_with_provider in model_cost_ref:
+        model = model_with_provider
     # see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
     print_verbose(f"Looking up model={model} in model_cost_map")
-    if model_with_provider in model_cost_ref:
-        print_verbose(
-            f"Success: model={model_with_provider} in model_cost_map - {model_cost_ref[model_with_provider]}"
-        )
-        print_verbose(
-            f"applying cost={model_cost_ref[model_with_provider].get('input_cost_per_token', None)} for prompt_tokens={prompt_tokens}"
-        )
-        prompt_tokens_cost_usd_dollar = (
-            model_cost_ref[model_with_provider]["input_cost_per_token"] * prompt_tokens
-        )
-        print_verbose(
-            f"calculated prompt_tokens_cost_usd_dollar: {prompt_tokens_cost_usd_dollar}"
-        )
-        print_verbose(
-            f"applying cost={model_cost_ref[model_with_provider].get('output_cost_per_token', None)} for completion_tokens={completion_tokens}"
-        )
-        completion_tokens_cost_usd_dollar = (
-            model_cost_ref[model_with_provider]["output_cost_per_token"]
-            * completion_tokens
-        )
-        print_verbose(
-            f"calculated completion_tokens_cost_usd_dollar: {completion_tokens_cost_usd_dollar}"
-        )
-        return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
     if model in model_cost_ref:
         print_verbose(f"Success: model={model} in model_cost_map")
         print_verbose(
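For context: the hunk above removes the duplicated provider-qualified pricing branch and instead remaps `model` before the shared lookup runs. A simplified sketch of the resulting lookup order (names shortened; the real function handles far more cases):

```python
# Simplified sketch of the lookup order after this change; not the real
# cost_per_token implementation, just the control flow it now follows.
def lookup_pricing(model, model_with_provider, model_with_provider_and_region, model_cost_ref):
    if model_with_provider_and_region in model_cost_ref:
        model_with_provider = model_with_provider_and_region  # region-based pricing
    if model_with_provider in model_cost_ref:
        model = model_with_provider
    if model in model_cost_ref:
        return (
            model_cost_ref[model]["input_cost_per_token"],
            model_cost_ref[model]["output_cost_per_token"],
        )
    raise KeyError(f"{model} not found in the cost map")
```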
@@ -7509,7 +7487,10 @@ class CustomStreamWrapper:
             logprobs = None
             original_chunk = None  # this is used for function/tool calling
             if len(str_line.choices) > 0:
-                if str_line.choices[0].delta.content is not None:
+                if (
+                    str_line.choices[0].delta is not None
+                    and str_line.choices[0].delta.content is not None
+                ):
                     text = str_line.choices[0].delta.content
                 else:  # function/tool calling chunk - when content is None. in this case we just return the original chunk from openai
                     original_chunk = str_line
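For context: the hunk above guards against streaming chunks whose `delta` is missing entirely, instead of only checking `delta.content`. A self-contained sketch of the same guard (the `chunk` argument stands in for a parsed streaming chunk, not a real litellm type):

```python
# Sketch of the guard added above; `chunk` is any object shaped like an OpenAI
# streaming chunk (choices[0].delta.content), not a real litellm type here.
def extract_text(chunk):
    text = ""
    original_chunk = None  # kept for the function/tool-calling path
    if len(chunk.choices) > 0:
        delta = chunk.choices[0].delta
        if delta is not None and delta.content is not None:
            text = delta.content
        else:
            # function/tool-calling chunk - content is None, keep the raw chunk
            original_chunk = chunk
    return text, original_chunk
```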
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.20.2"
+version = "1.20.6"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"

@@ -63,7 +63,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"

 [tool.commitizen]
-version = "1.20.2"
+version = "1.20.6"
 version_files = [
     "pyproject.toml:^version"
 ]
@@ -70,10 +70,11 @@ async def test_health_readiness():
         url = "http://0.0.0.0:4000/health/readiness"
         async with session.get(url) as response:
             status = response.status
-            response_text = await response.text()
+            response_json = await response.json()

-            print(response_text)
-            print()
+            print(response_json)
+            assert "litellm_version" in response_json
+            assert "status" in response_json

             if status != 200:
                 raise Exception(f"Request did not return a 200 status code: {status}")
@@ -1,20 +1,20 @@
-# Use official Python base image
-FROM python:3.9.12
+# Use an official Node.js image as the base image
+FROM node:18-alpine

-EXPOSE 8501
-
-# Set the working directory in the container
+# Set the working directory inside the container
 WORKDIR /app

-# Copy the requirements.txt file to the container
-COPY requirements.txt .
+# Copy package.json and package-lock.json to the working directory
+COPY ./litellm-dashboard/package*.json ./

-# Install Python dependencies
-RUN pip install --no-cache-dir -r requirements.txt
+# Install dependencies
+RUN npm install

-# Copy the entire project directory to the container
-COPY admin.py .
+# Copy the rest of the application code to the working directory
+COPY ./litellm-dashboard .

-# Set the entrypoint command to run admin.py with Streamlit
-ENTRYPOINT [ "streamlit", "run"]
-CMD ["admin.py"]
+# Expose the port that the Next.js app will run on
+EXPOSE 3000
+
+# Start the Next.js app
+CMD ["npm", "run", "dev"]
@@ -1,11 +1,14 @@

 "use client";

 import React, { useState, useEffect, useRef } from "react";
 import { Button, TextInput, Grid, Col } from "@tremor/react";
-import { message } from "antd";
 import { Card, Metric, Text } from "@tremor/react";
+import { Button as Button2, Modal, Form, Input, InputNumber, Select, message } from "antd";
 import { keyCreateCall } from "./networking";
-// Define the props type
+const { Option } = Select;
+
 interface CreateKeyProps {
   userID: string;
   accessToken: string;

@@ -14,8 +17,6 @@ interface CreateKeyProps {
   setData: React.Dispatch<React.SetStateAction<any[] | null>>;
 }

-import { Modal, Button as Button2 } from "antd";
-
 const CreateKey: React.FC<CreateKeyProps> = ({
   userID,
   accessToken,

@@ -23,49 +24,106 @@ const CreateKey: React.FC<CreateKeyProps> = ({
   data,
   setData,
 }) => {
+  const [form] = Form.useForm();
   const [isModalVisible, setIsModalVisible] = useState(false);
   const [apiKey, setApiKey] = useState(null);

   const handleOk = () => {
-    // Handle the OK action
-    console.log("OK Clicked");
     setIsModalVisible(false);
+    form.resetFields();
   };

   const handleCancel = () => {
-    // Handle the cancel action or closing the modal
-    console.log("Modal closed");
     setIsModalVisible(false);
     setApiKey(null);
+    form.resetFields();
   };

-  const handleCreate = async () => {
-    if (data == null) {
-      return;
-    }
+  const handleCreate = async (formValues: Record<string, any>) => {
     try {
       message.info("Making API Call");
+      // Check if "models" exists and is not an empty string
+      if (formValues.models && formValues.models.trim() !== '') {
+        // Format the "models" field as an array
+        formValues.models = formValues.models.split(',').map((model: string) => model.trim());
+      } else {
+        // If "models" is undefined or an empty string, set it to an empty array
+        formValues.models = [];
+      }
       setIsModalVisible(true);
-      const response = await keyCreateCall(proxyBaseUrl, accessToken, userID);
-      // Successfully completed the deletion. Update the state to trigger a rerender.
-      setData([...data, response]);
+      const response = await keyCreateCall(proxyBaseUrl, accessToken, userID, formValues);
+      setData((prevData) => (prevData ? [...prevData, response] : [response])); // Check if prevData is null
       setApiKey(response["key"]);
       message.success("API Key Created");
+      form.resetFields();
     } catch (error) {
-      console.error("Error deleting the key:", error);
-      // Handle any error situations, such as displaying an error message to the user.
+      console.error("Error creating the key:", error);
     }
   };


   return (
     <div>
-      <Button className="mx-auto" onClick={handleCreate}>
+      <Button className="mx-auto" onClick={() => setIsModalVisible(true)}>
         + Create New Key
       </Button>
       <Modal
-        title="Save your key"
-        open={isModalVisible}
+        title="Create Key"
+        visible={isModalVisible}
+        width={800}
+        footer={null}
         onOk={handleOk}
         onCancel={handleCancel}
+      >
+        <Form form={form} onFinish={handleCreate} labelCol={{ span: 6 }} wrapperCol={{ span: 16 }} labelAlign="left">
+          <Form.Item
+            label="Key Name"
+            name="key_alias"
+          >
+            <Input />
+          </Form.Item>
+          <Form.Item
+            label="Models (Comma Separated). Eg: gpt-3.5-turbo,gpt-4"
+            name="models"
+          >
+            <Input placeholder="gpt-4,gpt-3.5-turbo" />
+          </Form.Item>
+          <Form.Item
+            label="Max Budget (USD)"
+            name="max_budget"
+          >
+            <InputNumber step={0.01} precision={2} width={200}/>
+          </Form.Item>
+          <Form.Item
+            label="Duration (eg: 30s, 30h, 30d)"
+            name="duration"
+          >
+            <Input />
+          </Form.Item>
+          <Form.Item
+            label="Metadata"
+            name="metadata"
+          >
+            <Input.TextArea rows={4} placeholder="Enter metadata as JSON" />
+          </Form.Item>
+          <div style={{ textAlign: 'right', marginTop: '10px' }}>
+            <Button2 htmlType="submit">
+              Create Key
+            </Button2>
+          </div>
+        </Form>
+      </Modal>
+      {apiKey && (
+        <Modal
+          title="Save your key"
+          visible={isModalVisible}
+          onOk={handleOk}
+          onCancel={handleCancel}
+          footer={null}
        >
         <Grid numItems={1} className="gap-2 w-full">
           <Col numColSpan={1}>

@@ -85,6 +143,7 @@ const CreateKey: React.FC<CreateKeyProps> = ({
           </Col>
         </Grid>
       </Modal>
+      )}
     </div>
   );
 };
@@ -3,11 +3,13 @@
 */

 export const keyCreateCall = async (
-  proxyBaseUrl: String,
-  accessToken: String,
-  userID: String
+  proxyBaseUrl: string,
+  accessToken: string,
+  userID: string,
+  formValues: Record<string, any> // Assuming formValues is an object
 ) => {
   try {
+    console.log("Form Values in keyCreateCall:", formValues); // Log the form values before making the API call
     const response = await fetch(`${proxyBaseUrl}/key/generate`, {
       method: "POST",
       headers: {

@@ -15,18 +17,19 @@ export const keyCreateCall = async (
         "Content-Type": "application/json",
       },
       body: JSON.stringify({
-        team_id: "core-infra-4",
-        max_budget: 10,
         user_id: userID,
+        ...formValues, // Include formValues in the request body
       }),
     });

     if (!response.ok) {
+      const errorData = await response.json();
+      console.error("Error response from the server:", errorData);
       throw new Error("Network response was not ok");
     }

     const data = await response.json();
-    console.log(data);
+    console.log("API Response:", data);
     return data;
     // Handle success - you might want to update some state or UI based on the created key
   } catch (error) {

@@ -35,6 +38,7 @@ export const keyCreateCall = async (
   }
 };

+
 export const keyDeleteCall = async (
   proxyBaseUrl: String,
   accessToken: String,
@@ -43,11 +43,10 @@ const UserDashboard = () => {
     );
   }
   else if (userID == null || accessToken == null) {
-    // redirect to page: ProxyBaseUrl/google-login/key/generate
     const baseUrl = proxyBaseUrl.endsWith('/') ? proxyBaseUrl : proxyBaseUrl + '/';

     // Now you can construct the full URL
-    const url = `${baseUrl}google-login/key/generate`;
+    const url = `${baseUrl}sso/key/generate`;

     window.location.href = url;

@@ -58,7 +58,8 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
         <TableHead>
           <TableRow>
             <TableHeaderCell>Secret Key</TableHeaderCell>
-            <TableHeaderCell>Spend</TableHeaderCell>
+            <TableHeaderCell>Spend (USD)</TableHeaderCell>
+            <TableHeaderCell>Key Budget (USD)</TableHeaderCell>
             <TableHeaderCell>Expires</TableHeaderCell>
           </TableRow>
         </TableHead>

@@ -68,11 +69,24 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
             return (
               <TableRow key={item.token}>
                 <TableCell>
+                  {item.key_name != null ? (
                     <Text>{item.key_name}</Text>
+                  ) : (
+                    <Text>{item.token}</Text>
+                  )
+                  }
                 </TableCell>
                 <TableCell>
                   <Text>{item.spend}</Text>
                 </TableCell>
+                <TableCell>
+                  {item.max_budget != null ? (
+                    <Text>{item.max_budget}</Text>
+                  ) : (
+                    <Text>Unlimited Budget</Text>
+                  )
+                  }
+                </TableCell>
                 <TableCell>
                   {item.expires != null ? (
                     <Text>{item.expires}</Text>