Merge branch 'BerriAI:main' into add_helm_chart

This commit is contained in:
Shaun Maher 2024-01-31 14:27:21 +11:00 committed by GitHub
commit 8876b47e94
38 changed files with 860 additions and 237 deletions

View file

@ -156,7 +156,8 @@ jobs:
--config /app/config.yaml \
--port 4000 \
--num_workers 8 \
--debug
--debug \
--run_gunicorn \
- run:
name: Install curl and dockerize
command: |

View file

@ -52,4 +52,4 @@ RUN chmod +x entrypoint.sh
EXPOSE 4000/tcp
ENTRYPOINT ["litellm"]
CMD ["--port", "4000", "--config", "./proxy_server_config.yaml", "--detailed_debug"]
CMD ["--port", "4000", "--config", "./proxy_server_config.yaml", "--detailed_debug", "--run_gunicorn"]

View file

@ -56,4 +56,4 @@ EXPOSE 4000/tcp
# # Set your entrypoint and command
ENTRYPOINT ["litellm"]
CMD ["--port", "4000"]
CMD ["--port", "4000", "--run_gunicorn"]

View file

@ -188,7 +188,7 @@ print(response)
</Tabs>
## Save Model-specific params (API Base, API Keys, Temperature, Max Tokens, Seed, Headers etc.)
## Save Model-specific params (API Base, API Keys, Temperature, Max Tokens, Seed, Organization, Headers etc.)
You can use the config to save model-specific information like api_base, api_key, temperature, max_tokens, etc.
[**All input params**](https://docs.litellm.ai/docs/completion/input#input-params-1)
@ -210,6 +210,12 @@ model_list:
api_key: sk-123
api_base: https://openai-gpt-4-test-v-2.openai.azure.com/
temperature: 0.2
- model_name: openai-gpt-3.5
litellm_params:
model: openai/gpt-3.5-turbo
api_key: sk-123
organization: org-ikDc4ex8NB
temperature: 0.2
- model_name: mistral-7b
litellm_params:
model: ollama/mistral
@ -483,3 +489,55 @@ general_settings:
max_parallel_requests: 100 # max parallel requests for a user = 100
```
## All settings
```python
{
"environment_variables": {},
"model_list": [
{
"model_name": "string",
"litellm_params": {},
"model_info": {
"id": "string",
"mode": "embedding",
"input_cost_per_token": 0,
"output_cost_per_token": 0,
"max_tokens": 2048,
"base_model": "gpt-4-1106-preview",
"additionalProp1": {}
}
}
],
"litellm_settings": {}, # ALL (https://github.com/BerriAI/litellm/blob/main/litellm/__init__.py)
"general_settings": {
"completion_model": "string",
"key_management_system": "google_kms", # either google_kms or azure_kms
"master_key": "string",
"database_url": "string",
"database_type": "dynamo_db",
"database_args": {
"billing_mode": "PROVISIONED_THROUGHPUT",
"read_capacity_units": 0,
"write_capacity_units": 0,
"ssl_verify": true,
"region_name": "string",
"user_table_name": "LiteLLM_UserTable",
"key_table_name": "LiteLLM_VerificationToken",
"config_table_name": "LiteLLM_Config",
"spend_table_name": "LiteLLM_SpendLogs"
},
"otel": true,
"custom_auth": "string",
"max_parallel_requests": 0,
"infer_model_from_keys": true,
"background_health_checks": true,
"health_check_interval": 300,
"alerting": [
"string"
],
"alerting_threshold": 0
}
}
```

View file

@ -0,0 +1,34 @@
# Debugging
Two levels of debugging are supported:
- debug (prints info logs)
- detailed debug (prints debug logs)
## `debug`
**via cli**
```bash
$ litellm --debug
```
**via env**
```python
os.environ["LITELLM_LOG"] = "INFO"
```
## `detailed debug`
**via cli**
```bash
$ litellm --detailed_debug
```
**via env**
```python
os.environ["LITELLM_LOG"] = "DEBUG"
```
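For reference, the env var can be combined with the proxy start command (a minimal sketch, assuming a config at `/app/config.yaml` and port 4000 as used elsewhere in this repo):
```bash
# enable debug-level logs via the env var, then start the proxy
export LITELLM_LOG="DEBUG"
litellm --config /app/config.yaml --port 4000
```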

View file

@ -112,7 +112,8 @@ Example Response:
```json
{
"status": "healthy",
"db": "connected"
"db": "connected",
"litellm_version":"1.19.2",
}
```
@ -121,7 +122,8 @@ Example Response:
```json
{
"status": "healthy",
"db": "Not connected"
"db": "Not connected",
"litellm_version":"1.19.2",
}
```
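A quick way to see this response locally (a sketch assuming the proxy is running on port 4000):
```bash
# hit the readiness endpoint; the response now includes litellm_version
curl -s http://0.0.0.0:4000/health/readiness
```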

View file

@ -1,4 +1,6 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# [BETA] Admin UI
@ -26,22 +28,27 @@ general_settings:
allow_user_auth: true
```
## 2. Setup Google SSO - Use this to Authenticate Team Members to the UI
- Create an Oauth 2.0 Client
<Image img={require('../../img/google_oauth2.png')} />
## 2. Setup SSO/Auth for UI
- Navigate to Google `Credentials`
- Create a new Oauth client ID
- Set the `GOOGLE_CLIENT_ID` and `GOOGLE_CLIENT_SECRET` in your Proxy .env
- Set Redirect URL on your Oauth 2.0 Client
- Click on your Oauth 2.0 client on https://console.cloud.google.com/
- Set a redirect url = `<your proxy base url>/google-callback`
```
https://litellm-production-7002.up.railway.app/google-callback
```
<Image img={require('../../img/google_redirect.png')} />
## 3. Required env variables on your Proxy
<Tabs>
<TabItem value="username" label="Quick Start - Username, Password">
Set the following in your .env on the Proxy
```shell
UI_USERNAME=ishaan-litellm
UI_PASSWORD=langchain
```
On accessing the LiteLLM UI, you will be prompted to enter your username and password
</TabItem>
<TabItem value="google" label="Google SSO">
- Create a new Oauth 2.0 Client on https://console.cloud.google.com/
**Required .env variables on your Proxy**
```shell
PROXY_BASE_URL="<your deployed proxy endpoint>" # example: PROXY_BASE_URL=https://litellm-production-7002.up.railway.app/
@ -50,6 +57,37 @@ GOOGLE_CLIENT_ID=
GOOGLE_CLIENT_SECRET=
```
- Set Redirect URL on your Oauth 2.0 Client on https://console.cloud.google.com/
- Set a redirect url = `<your proxy base url>/sso/callback`
```shell
https://litellm-production-7002.up.railway.app/sso/callback
```
</TabItem>
<TabItem value="msft" label="Microsoft SSO">
- Create a new App Registration on https://portal.azure.com/
- Create a client Secret for your App Registration
**Required .env variables on your Proxy**
```shell
PROXY_BASE_URL="<your deployed proxy endpoint>" # example: PROXY_BASE_URL=https://litellm-production-7002.up.railway.app/
MICROSOFT_CLIENT_ID="84583a4d-"
MICROSOFT_CLIENT_SECRET="nbk8Q~"
MICROSOFT_TENANT="5a39737"
```
- Set Redirect URI on your App Registration on https://portal.azure.com/
- Set a redirect url = `<your proxy base url>/sso/callback`
```shell
http://localhost:4000/sso/callback
```
</TabItem>
</Tabs>
## 4. Use UI
👉 Get Started here: https://litellm-dashboard.vercel.app/

View file

@ -278,6 +278,21 @@ Request Params:
}
```
## Default /key/generate params
Use this if you need to control the default `max_budget` or any other `/key/generate` param per key.
When a `/key/generate` request does not specify `max_budget`, it will use the `max_budget` specified in `default_key_generate_params`
Set `litellm_settings:default_key_generate_params`:
```yaml
litellm_settings:
default_key_generate_params:
max_budget: 1.5000
models: ["azure-gpt-3.5"]
duration: # blank means `null`
metadata: {"setting":"default"}
team_id: "core-infra"
```
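As an illustration (a sketch assuming a locally running proxy and the example `master_key: sk-1234` from the sample proxy config), a `/key/generate` request that omits these params picks up the configured defaults:
```bash
# empty body -> key is created with the default max_budget, models, etc. from default_key_generate_params
curl -s -X POST http://0.0.0.0:4000/key/generate \
  -H "Authorization: Bearer sk-1234" \
  -H "Content-Type: application/json" \
  -d '{}'
```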
## Set Budgets - Per Key
Set the `max_budget` param (in USD) on the `/key/generate` request. By default, `max_budget` is set to `null` and is not checked for keys.
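For example (same local-proxy assumptions as the sketch above):
```bash
# create a key capped at 10 USD of spend
curl -s -X POST http://0.0.0.0:4000/key/generate \
  -H "Authorization: Bearer sk-1234" \
  -H "Content-Type: application/json" \
  -d '{"max_budget": 10}'
```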

View file

@ -115,6 +115,7 @@ const sidebars = {
"proxy/ui",
"proxy/model_management",
"proxy/health",
"proxy/debugging",
{
"type": "category",
"label": "🔥 Load Balancing",

View file

@ -143,6 +143,7 @@ model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/mai
suppress_debug_info = False
dynamodb_table_name: Optional[str] = None
s3_callback_params: Optional[Dict] = None
default_key_generate_params: Optional[Dict] = None
#### RELIABILITY ####
request_timeout: Optional[float] = 6000
num_retries: Optional[int] = None # per model endpoint

View file

@ -702,6 +702,11 @@ def _embedding_func_single(
encoding=None,
logging_obj=None,
):
if isinstance(input, str) is False:
raise BedrockError(
message="Bedrock Embedding API input must be type str | List[str]",
status_code=400,
)
# logic for parsing in - calling - parsing out model embedding calls
## FORMAT EMBEDDING INPUT ##
provider = model.split(".")[0]
@ -795,7 +800,8 @@ def embedding(
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
)
if type(input) == str:
if isinstance(input, str):
## Embedding Call
embeddings = [
_embedding_func_single(
model,
@ -805,8 +811,8 @@ def embedding(
logging_obj=logging_obj,
)
]
else:
## Embedding Call
elif isinstance(input, list):
## Embedding Call - assuming this is a List[str]
embeddings = [
_embedding_func_single(
model,
@ -817,6 +823,12 @@ def embedding(
)
for i in input
] # [TODO]: make these parallel calls
else:
# enters this branch if input is an int, e.g. input=2
raise BedrockError(
message="Bedrock Embedding API input must be type str | List[str]",
status_code=400,
)
## Populate OpenAI compliant dictionary
embedding_response = []

View file

@ -145,8 +145,8 @@ def get_ollama_response(
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
optional_params["stream"] = optional_params.get("stream", False)
data = {"model": model, "messages": messages, **optional_params}
stream = optional_params.pop("stream", False)
data = {"model": model, "messages": messages, "options": optional_params}
## LOGGING
logging_obj.pre_call(
input=None,
@ -159,7 +159,7 @@ def get_ollama_response(
},
)
if acompletion is True:
if optional_params.get("stream", False) == True:
if stream == True:
response = ollama_async_streaming(
url=url,
data=data,
@ -176,7 +176,7 @@ def get_ollama_response(
logging_obj=logging_obj,
)
return response
elif optional_params.get("stream", False) == True:
elif stream == True:
return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
response = requests.post(

View file

@ -221,6 +221,7 @@ class OpenAIChatCompletion(BaseLLM):
headers: Optional[dict] = None,
custom_prompt_dict: dict = {},
client=None,
organization: Optional[str] = None,
):
super().completion()
exception_mapping_worked = False
@ -254,6 +255,7 @@ class OpenAIChatCompletion(BaseLLM):
timeout=timeout,
client=client,
max_retries=max_retries,
organization=organization,
)
else:
return self.acompletion(
@ -266,6 +268,7 @@ class OpenAIChatCompletion(BaseLLM):
timeout=timeout,
client=client,
max_retries=max_retries,
organization=organization,
)
elif optional_params.get("stream", False):
return self.streaming(
@ -278,6 +281,7 @@ class OpenAIChatCompletion(BaseLLM):
timeout=timeout,
client=client,
max_retries=max_retries,
organization=organization,
)
else:
if not isinstance(max_retries, int):
@ -291,6 +295,7 @@ class OpenAIChatCompletion(BaseLLM):
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_client = client
@ -358,6 +363,7 @@ class OpenAIChatCompletion(BaseLLM):
timeout: float,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None,
@ -372,6 +378,7 @@ class OpenAIChatCompletion(BaseLLM):
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_aclient = client
@ -412,6 +419,7 @@ class OpenAIChatCompletion(BaseLLM):
model: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
headers=None,
@ -423,6 +431,7 @@ class OpenAIChatCompletion(BaseLLM):
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_client = client
@ -454,6 +463,7 @@ class OpenAIChatCompletion(BaseLLM):
model: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
headers=None,
@ -467,6 +477,7 @@ class OpenAIChatCompletion(BaseLLM):
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_aclient = client
@ -748,8 +759,11 @@ class OpenAIChatCompletion(BaseLLM):
messages: Optional[list] = None,
input: Optional[list] = None,
prompt: Optional[str] = None,
organization: Optional[str] = None,
):
client = AsyncOpenAI(api_key=api_key, timeout=timeout)
client = AsyncOpenAI(
api_key=api_key, timeout=timeout, organization=organization
)
if model is None and mode != "image_generation":
raise Exception("model is not set")

View file

@ -104,8 +104,8 @@ def mistral_instruct_pt(messages):
initial_prompt_value="<s>",
role_dict={
"system": {
"pre_message": "[INST] <<SYS>>\n",
"post_message": "<</SYS>> [/INST]\n",
"pre_message": "[INST] \n",
"post_message": " [/INST]\n",
},
"user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"},
"assistant": {"pre_message": " ", "post_message": " "},
@ -376,6 +376,7 @@ def anthropic_pt(
You can "put words in Claude's mouth" by ending with an assistant message.
See: https://docs.anthropic.com/claude/docs/put-words-in-claudes-mouth
"""
class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
@ -403,27 +404,30 @@ def _load_image_from_url(image_url):
try:
from PIL import Image
except:
raise Exception("gemini image conversion failed please run `pip install Pillow`")
raise Exception(
"gemini image conversion failed please run `pip install Pillow`"
)
from io import BytesIO
try:
# Send a GET request to the image URL
response = requests.get(image_url)
response.raise_for_status() # Raise an exception for HTTP errors
# Check the response's content type to ensure it is an image
content_type = response.headers.get('content-type')
if not content_type or 'image' not in content_type:
raise ValueError(f"URL does not point to a valid image (content-type: {content_type})")
content_type = response.headers.get("content-type")
if not content_type or "image" not in content_type:
raise ValueError(
f"URL does not point to a valid image (content-type: {content_type})"
)
# Load the image from the response content
return Image.open(BytesIO(response.content))
except requests.RequestException as e:
print(f"Request failed: {e}")
except UnidentifiedImageError:
print("Cannot identify image file (it may not be a supported image format or might be corrupted).")
except ValueError as e:
print(e)
raise Exception(f"Request failed: {e}")
except Exception as e:
raise e
def _gemini_vision_convert_messages(messages: list):
@ -441,10 +445,11 @@ def _gemini_vision_convert_messages(messages: list):
try:
from PIL import Image
except:
raise Exception("gemini image conversion failed please run `pip install Pillow`")
raise Exception(
"gemini image conversion failed please run `pip install Pillow`"
)
try:
# given messages for gpt-4 vision, convert them for gemini
# https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb
prompt = ""

View file

@ -237,8 +237,11 @@ def completion(
GenerationConfig,
)
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
import google.auth
vertexai.init(project=vertex_project, location=vertex_location)
## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
creds, _ = google.auth.default(quota_project_id=vertex_project)
vertexai.init(project=vertex_project, location=vertex_location, credentials=creds)
## Load Config
config = litellm.VertexAIConfig.get_config()

View file

@ -450,6 +450,7 @@ def completion(
num_retries = kwargs.get("num_retries", None) ## deprecated
max_retries = kwargs.get("max_retries", None)
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
organization = kwargs.get("organization", None)
### CUSTOM MODEL COST ###
input_cost_per_token = kwargs.get("input_cost_per_token", None)
output_cost_per_token = kwargs.get("output_cost_per_token", None)
@ -591,26 +592,37 @@ def completion(
### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
if input_cost_per_token is not None and output_cost_per_token is not None:
print_verbose(f"Registering model={model} in model cost map")
litellm.register_model(
{
f"{custom_llm_provider}/{model}": {
"input_cost_per_token": input_cost_per_token,
"output_cost_per_token": output_cost_per_token,
"litellm_provider": custom_llm_provider,
},
model: {
"input_cost_per_token": input_cost_per_token,
"output_cost_per_token": output_cost_per_token,
"litellm_provider": custom_llm_provider,
}
},
}
)
if (
elif (
input_cost_per_second is not None
): # time based pricing just needs cost in place
output_cost_per_second = output_cost_per_second or 0.0
litellm.register_model(
{
f"{custom_llm_provider}/{model}": {
"input_cost_per_second": input_cost_per_second,
"output_cost_per_second": output_cost_per_second,
"litellm_provider": custom_llm_provider,
},
model: {
"input_cost_per_second": input_cost_per_second,
"output_cost_per_second": output_cost_per_second,
"litellm_provider": custom_llm_provider,
}
},
}
)
### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ###
@ -787,7 +799,8 @@ def completion(
or "https://api.openai.com/v1"
)
openai.organization = (
litellm.organization
organization
or litellm.organization
or get_secret("OPENAI_ORGANIZATION")
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
@ -827,6 +840,7 @@ def completion(
timeout=timeout,
custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client
organization=organization,
)
except Exception as e:
## LOGGING - log the original exception returned
@ -3224,6 +3238,7 @@ async def ahealth_check(
or custom_llm_provider == "text-completion-openai"
):
api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY")
organization = model_params.get("organization")
timeout = (
model_params.get("timeout")
@ -3241,6 +3256,7 @@ async def ahealth_check(
mode=mode,
prompt=prompt,
input=input,
organization=organization,
)
else:
if mode == "embedding":
@ -3265,6 +3281,7 @@ async def ahealth_check(
## Set verbose to true -> ```litellm.set_verbose = True```
def print_verbose(print_statement):
try:
verbose_logger.debug(print_statement)
if litellm.set_verbose:
print(print_statement) # noqa
except:

View file

@ -233,10 +233,6 @@ class DynamoDBWrapper(CustomDB):
table = client.table(self.database_arguments.config_table_name)
key_name = "param_name"
if key_name == "token" and key.startswith("sk-"):
# ensure it's hashed
key = hash_token(token=key)
response = await table.get_item({key_name: key})
new_response: Any = None

View file

@ -17,7 +17,12 @@ class MaxParallelRequestsHandler(CustomLogger):
pass
def print_verbose(self, print_statement):
try:
verbose_proxy_logger.debug(print_statement)
if litellm.set_verbose:
print(print_statement) # noqa
except:
pass
async def async_pre_call_hook(
self,

View file

@ -157,6 +157,12 @@ def is_port_in_use(port):
type=int,
help="Number of requests to hit async endpoint with",
)
@click.option(
"--run_gunicorn",
default=False,
is_flag=True,
help="Starts proxy via gunicorn, instead of uvicorn (better for managing multiple workers)",
)
@click.option("--local", is_flag=True, default=False, help="for local debugging")
def run_server(
host,
@ -186,6 +192,7 @@ def run_server(
use_queue,
health,
version,
run_gunicorn,
):
global feature_telemetry
args = locals()
@ -439,9 +446,9 @@ def run_server(
port = random.randint(1024, 49152)
from litellm.proxy.proxy_server import app
if os.name == "nt":
if run_gunicorn == False:
uvicorn.run(app, host=host, port=port) # run uvicorn
else:
elif run_gunicorn == True:
import gunicorn.app.base
# Gunicorn Application Class

View file

@ -69,11 +69,16 @@ litellm_settings:
success_callback: ['langfuse']
max_budget: 10 # global budget for proxy
budget_duration: 30d # global budget duration, will reset after 30d
default_key_generate_params:
max_budget: 1.5000
models: ["azure-gpt-3.5"]
duration: None
# cache: True
# setting callback class
# callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
general_settings:
allow_user_auth: True
master_key: sk-1234
alerting: ["slack"]
alerting_threshold: 10 # sends alerts if requests hang for 10 seconds

View file

@ -76,6 +76,7 @@ from litellm.proxy.utils import (
get_logging_payload,
reset_budget,
hash_token,
html_form,
)
from litellm.proxy.secret_managers.google_kms import load_google_kms
import pydantic
@ -94,6 +95,7 @@ from fastapi import (
BackgroundTasks,
Header,
Response,
Form,
)
from fastapi.routing import APIRouter
from fastapi.security import OAuth2PasswordBearer
@ -245,8 +247,6 @@ async def user_api_key_auth(
response = await user_custom_auth(request=request, api_key=api_key)
return UserAPIKeyAuth.model_validate(response)
### LITELLM-DEFINED AUTH FUNCTION ###
if isinstance(api_key, str):
assert api_key.startswith("sk-") # prevent token hashes from being used
if master_key is None:
if isinstance(api_key, str):
return UserAPIKeyAuth(api_key=api_key)
@ -283,6 +283,10 @@ async def user_api_key_auth(
if is_master_key_valid:
return UserAPIKeyAuth(api_key=master_key)
if isinstance(
api_key, str
): # if generated token, make sure it starts with sk-.
assert api_key.startswith("sk-") # prevent token hashes from being used
if route.startswith("/config/") and not is_master_key_valid:
raise Exception(f"Only admin can modify config")
@ -292,6 +296,7 @@ async def user_api_key_auth(
raise Exception("No connected db.")
## check for cache hit (In-Memory Cache)
original_api_key = api_key # (Patch: For DynamoDB Backwards Compatibility)
if api_key.startswith("sk-"):
api_key = hash_token(token=api_key)
valid_token = user_api_key_cache.get_cache(key=api_key)
@ -304,10 +309,15 @@ async def user_api_key_auth(
)
elif custom_db_client is not None:
try:
valid_token = await custom_db_client.get_data(
key=api_key, table_name="key"
)
except:
# (Patch: For DynamoDB Backwards Compatibility)
valid_token = await custom_db_client.get_data(
key=original_api_key, table_name="key"
)
verbose_proxy_logger.debug(f"Token from db: {valid_token}")
elif valid_token is not None:
verbose_proxy_logger.debug(f"API Key Cache Hit!")
@ -1117,6 +1127,9 @@ class ProxyConfig:
# see usage here: https://docs.litellm.ai/docs/proxy/caching
pass
else:
verbose_proxy_logger.debug(
f"{blue_color_code} setting litellm.{key}={value}{reset_color_code}"
)
setattr(litellm, key, value)
## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
@ -2385,6 +2398,26 @@ async def generate_key_fn(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=message
)
# check if user set default key/generate params on config.yaml
if litellm.default_key_generate_params is not None:
for elem in data:
key, value = elem
if value is None and key in [
"max_budget",
"user_id",
"team_id",
"max_parallel_requests",
"tpm_limit",
"rpm_limit",
"budget_duration",
]:
setattr(
data, key, litellm.default_key_generate_params.get(key, None)
)
elif key == "models" and value == []:
setattr(data, key, litellm.default_key_generate_params.get(key, []))
elif key == "metadata" and value == {}:
setattr(data, key, litellm.default_key_generate_params.get(key, {}))
data_json = data.json() # type: ignore
@ -2854,7 +2887,7 @@ async def user_auth(request: Request):
return "Email sent!"
@app.get("/google-login/key/generate", tags=["experimental"])
@app.get("/sso/key/generate", tags=["experimental"])
async def google_login(request: Request):
"""
Create Proxy API Keys using Google Workspace SSO. Requires setting GOOGLE_REDIRECT_URI in .env
@ -2863,102 +2896,108 @@ async def google_login(request: Request):
Example:
"""
GOOGLE_REDIRECT_URI = os.getenv("PROXY_BASE_URL")
if GOOGLE_REDIRECT_URI is None:
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
redirect_url = os.getenv("PROXY_BASE_URL", None)
if redirect_url is None:
raise ProxyException(
message="PROXY_BASE_URL not set. Set it in .env file",
type="auth_error",
param="PROXY_BASE_URL",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if GOOGLE_REDIRECT_URI.endswith("/"):
GOOGLE_REDIRECT_URI += "google-callback"
if redirect_url.endswith("/"):
redirect_url += "sso/callback"
else:
GOOGLE_REDIRECT_URI += "/google-callback"
redirect_url += "/sso/callback"
# Google SSO Auth
if google_client_id is not None:
from fastapi_sso.sso.google import GoogleSSO
GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID")
if GOOGLE_CLIENT_ID is None:
GOOGLE_CLIENT_ID = (
"246483686424-clje5sggkjma26ilktj6qssakqhoon0m.apps.googleusercontent.com"
)
verbose_proxy_logger.info(
f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {GOOGLE_REDIRECT_URI}\nGOOGLE_CLIENT_ID: {GOOGLE_CLIENT_ID}"
)
google_auth_url = f"https://accounts.google.com/o/oauth2/auth?client_id={GOOGLE_CLIENT_ID}&redirect_uri={GOOGLE_REDIRECT_URI}&response_type=code&scope=openid%20profile%20email"
return RedirectResponse(url=google_auth_url)
@app.get("/google-callback", tags=["experimental"], response_model=GenerateKeyResponse)
async def google_callback(code: str, request: Request):
import httpx
GOOGLE_REDIRECT_URI = os.getenv("PROXY_BASE_URL")
if GOOGLE_REDIRECT_URI is None:
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
if google_client_secret is None:
raise ProxyException(
message="PROXY_BASE_URL not set. Set it in .env file",
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="PROXY_BASE_URL",
param="GOOGLE_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
# Add "/google-callback"" to your callback URL
if GOOGLE_REDIRECT_URI.endswith("/"):
GOOGLE_REDIRECT_URI += "google-callback"
else:
GOOGLE_REDIRECT_URI += "/google-callback"
GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID")
if GOOGLE_CLIENT_ID is None:
GOOGLE_CLIENT_ID = (
"246483686424-clje5sggkjma26ilktj6qssakqhoon0m.apps.googleusercontent.com"
google_sso = GoogleSSO(
client_id=google_client_id,
client_secret=google_client_secret,
redirect_uri=redirect_url,
)
GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET")
if GOOGLE_CLIENT_SECRET is None:
GOOGLE_CLIENT_SECRET = "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"
verbose_proxy_logger.info(
f"/google-callback\n GOOGLE_REDIRECT_URI: {GOOGLE_REDIRECT_URI}\n GOOGLE_CLIENT_ID: {GOOGLE_CLIENT_ID}"
f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
)
# Exchange code for access token
async with httpx.AsyncClient() as client:
token_url = f"https://oauth2.googleapis.com/token"
data = {
"code": code,
"client_id": GOOGLE_CLIENT_ID,
"client_secret": GOOGLE_CLIENT_SECRET,
"redirect_uri": GOOGLE_REDIRECT_URI,
"grant_type": "authorization_code",
}
response = await client.post(token_url, data=data)
# Process the response, extract user info, etc.
if response.status_code == 200:
access_token = response.json()["access_token"]
with google_sso:
return await google_sso.get_login_redirect()
# Fetch user info using the access token
async with httpx.AsyncClient() as client:
user_info_url = "https://www.googleapis.com/oauth2/v1/userinfo"
headers = {"Authorization": f"Bearer {access_token}"}
user_info_response = await client.get(user_info_url, headers=headers)
# Microsoft SSO Auth
elif microsoft_client_id is not None:
from fastapi_sso.sso.microsoft import MicrosoftSSO
# Process user info response
if user_info_response.status_code == 200:
user_info = user_info_response.json()
user_email = user_info.get("email")
user_name = user_info.get("name")
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
if microsoft_client_secret is None:
raise ProxyException(
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="MICROSOFT_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
# we can use user_email on litellm proxy now
microsoft_sso = MicrosoftSSO(
client_id=microsoft_client_id,
client_secret=microsoft_client_secret,
tenant=microsoft_tenant,
redirect_uri=redirect_url,
allow_insecure_http=True,
)
with microsoft_sso:
return await microsoft_sso.get_login_redirect()
else:
# No Google, Microsoft SSO
# Use UI Credentials set in .env
from fastapi.responses import HTMLResponse
# TODO: Handle user info as needed, for example, store it in a database, authenticate the user, etc.
return HTMLResponse(content=html_form, status_code=200)
@router.post(
"/login", include_in_schema=False
) # hidden since this is a helper for UI sso login
async def login(request: Request):
try:
import multipart
except ImportError:
subprocess.run(["pip", "install", "python-multipart"])
form = await request.form()
username = str(form.get("username"))
password = form.get("password")
ui_username = os.getenv("UI_USERNAME")
ui_password = os.getenv("UI_PASSWORD")
if username == ui_username and password == ui_password:
user_id = username
response = await generate_key_helper_fn(
**{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_email, "team_id": "litellm-dashboard"} # type: ignore
**{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard"} # type: ignore
)
key = response["token"] # type: ignore
user_id = response["user_id"] # type: ignore
litellm_dashboard_ui = "https://litellm-dashboard.vercel.app/"
# if user set LITELLM_UI_LINK in .env, use that
litellm_ui_link_in_env = os.getenv("LITELLM_UI_LINK", None)
if litellm_ui_link_in_env is not None:
litellm_dashboard_ui = litellm_ui_link_in_env
litellm_dashboard_ui += (
"?userID="
+ user_id
@ -2968,16 +3007,108 @@ async def google_callback(code: str, request: Request):
+ os.getenv("PROXY_BASE_URL")
)
return RedirectResponse(url=litellm_dashboard_ui)
else:
# Handle user info retrieval error
raise HTTPException(
status_code=user_info_response.status_code,
detail=user_info_response.text,
raise ProxyException(
message=f"Invalid credentials used to access UI. Passed in username: {username}, passed in password: {password}.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
type="auth_error",
param="invalid_credentials",
code=status.HTTP_401_UNAUTHORIZED,
)
@app.get("/sso/callback", tags=["experimental"])
async def auth_callback(request: Request):
"""Verify login"""
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
redirect_url = os.getenv("PROXY_BASE_URL", None)
if redirect_url is None:
raise ProxyException(
message="PROXY_BASE_URL not set. Set it in .env file",
type="auth_error",
param="PROXY_BASE_URL",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if redirect_url.endswith("/"):
redirect_url += "sso/callback"
else:
# Handle the error from the token exchange
raise HTTPException(status_code=response.status_code, detail=response.text)
redirect_url += "/sso/callback"
if google_client_id is not None:
from fastapi_sso.sso.google import GoogleSSO
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
if google_client_secret is None:
raise ProxyException(
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="GOOGLE_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
google_sso = GoogleSSO(
client_id=google_client_id,
redirect_uri=redirect_url,
client_secret=google_client_secret,
)
result = await google_sso.verify_and_process(request)
elif microsoft_client_id is not None:
from fastapi_sso.sso.microsoft import MicrosoftSSO
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
if microsoft_client_secret is None:
raise ProxyException(
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="MICROSOFT_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if microsoft_tenant is None:
raise ProxyException(
message="MICROSOFT_TENANT not set. Set it in .env file",
type="auth_error",
param="MICROSOFT_TENANT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
microsoft_sso = MicrosoftSSO(
client_id=microsoft_client_id,
client_secret=microsoft_client_secret,
tenant=microsoft_tenant,
redirect_uri=redirect_url,
allow_insecure_http=True,
)
result = await microsoft_sso.verify_and_process(request)
# User is Authe'd in - generate key for the UI to access Proxy
user_id = getattr(result, "email", None)
if user_id is None:
user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "")
response = await generate_key_helper_fn(
**{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard"} # type: ignore
)
key = response["token"] # type: ignore
user_id = response["user_id"] # type: ignore
litellm_dashboard_ui = "https://litellm-dashboard.vercel.app/"
# if user set LITELLM_UI_LINK in .env, use that
litellm_ui_link_in_env = os.getenv("LITELLM_UI_LINK", None)
if litellm_ui_link_in_env is not None:
litellm_dashboard_ui = litellm_ui_link_in_env
litellm_dashboard_ui += (
"?userID="
+ user_id
+ "&accessToken="
+ key
+ "&proxyBaseUrl="
+ os.getenv("PROXY_BASE_URL")
)
return RedirectResponse(url=litellm_dashboard_ui)
@router.get(
@ -3589,6 +3720,8 @@ async def health_readiness():
cache_type = None
if litellm.cache is not None:
cache_type = litellm.cache.type
from litellm._version import version
if prisma_client is not None: # if db passed in, check if it's connected
if prisma_client.db.is_connected() == True:
response_object = {"db": "connected"}
@ -3597,6 +3730,7 @@ async def health_readiness():
"status": "healthy",
"db": "connected",
"cache": cache_type,
"litellm_version": version,
"success_callbacks": litellm.success_callback,
}
else:
@ -3604,6 +3738,7 @@ async def health_readiness():
"status": "healthy",
"db": "Not connected",
"cache": cache_type,
"litellm_version": version,
"success_callbacks": litellm.success_callback,
}
raise HTTPException(status_code=503, detail="Service Unhealthy")

View file

@ -21,6 +21,7 @@ from datetime import datetime, timedelta
def print_verbose(print_statement):
verbose_proxy_logger.debug(print_statement)
if litellm.set_verbose:
print(f"LiteLLM Proxy: {print_statement}") # noqa
@ -96,6 +97,7 @@ class ProxyLogging:
2. /embeddings
3. /image/generation
"""
print_verbose(f"Inside Proxy Logging Pre-call hook!")
### ALERTING ###
asyncio.create_task(self.response_taking_too_long(request_data=data))
@ -1035,7 +1037,7 @@ async def send_email(sender_name, sender_email, receiver_email, subject, html):
print_verbose(f"SMTP Connection Init")
# Establish a secure connection with the SMTP server
with smtplib.SMTP(smtp_host, smtp_port) as server:
if os.getenv("SMTP_TLS", 'True') != "False":
if os.getenv("SMTP_TLS", "True") != "False":
server.starttls()
# Login to your email account
@ -1206,3 +1208,67 @@ async def reset_budget(prisma_client: PrismaClient):
await prisma_client.update_data(
query_type="update_many", data_list=users_to_reset, table_name="user"
)
# LiteLLM Admin UI - Non SSO Login
html_form = """
<!DOCTYPE html>
<html>
<head>
<title>LiteLLM Login</title>
<style>
body {
font-family: Arial, sans-serif;
background-color: #f4f4f4;
margin: 0;
padding: 0;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
}
form {
background-color: #fff;
padding: 20px;
border-radius: 8px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}
label {
display: block;
margin-bottom: 8px;
}
input {
width: 100%;
padding: 8px;
margin-bottom: 16px;
box-sizing: border-box;
border: 1px solid #ccc;
border-radius: 4px;
}
input[type="submit"] {
background-color: #4caf50;
color: #fff;
cursor: pointer;
}
input[type="submit"]:hover {
background-color: #45a049;
}
</style>
</head>
<body>
<form action="/login" method="post">
<h2>LiteLLM Login</h2>
<label for="username">Username:</label>
<input type="text" id="username" name="username" required>
<label for="password">Password:</label>
<input type="password" id="password" name="password" required>
<input type="submit" value="Submit">
</form>
</body>
</html>
"""

View file

@ -1411,6 +1411,12 @@ class Router:
max_retries = litellm.get_secret(max_retries_env_name)
litellm_params["max_retries"] = max_retries
organization = litellm_params.get("organization", None)
if isinstance(organization, str) and organization.startswith("os.environ/"):
organization_env_name = organization.replace("os.environ/", "")
organization = litellm.get_secret(organization_env_name)
litellm_params["organization"] = organization
if "azure" in model_name:
if api_base is None:
raise ValueError(
@ -1610,6 +1616,7 @@ class Router:
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(),
limits=httpx.Limits(
@ -1630,6 +1637,7 @@ class Router:
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.Client(
transport=CustomHTTPTransport(),
limits=httpx.Limits(
@ -1651,6 +1659,7 @@ class Router:
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(),
limits=httpx.Limits(
@ -1672,6 +1681,7 @@ class Router:
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.Client(
transport=CustomHTTPTransport(),
limits=httpx.Limits(

View file

@ -70,18 +70,16 @@ def test_completion_with_empty_model():
def test_completion_invalid_param_cohere():
try:
response = completion(model="command-nightly", messages=messages, top_p=1)
print(f"response: {response}")
litellm.set_verbose = True
response = completion(model="command-nightly", messages=messages, seed=12)
pytest.fail(f"This should have failed cohere does not support `seed` parameter")
except Exception as e:
if "Unsupported parameters passed: top_p" in str(e):
if " cohere does not support parameters: {'seed': 12}" in str(e):
pass
else:
pytest.fail(f"An error occurred {e}")
# test_completion_invalid_param_cohere()
def test_completion_function_call_cohere():
try:
response = completion(

View file

@ -515,7 +515,7 @@ def hf_test_completion_tgi():
# hf_test_error_logs()
# def test_completion_cohere(): # commenting for now as the cohere endpoint is being flaky
# def test_completion_cohere(): # commenting out for now, as the cohere endpoint is being flaky
# try:
# litellm.CohereConfig(max_tokens=10, stop_sequences=["a"])
# response = completion(
@ -569,6 +569,22 @@ def test_completion_openai():
# test_completion_openai()
def test_completion_openai_organization():
try:
litellm.set_verbose = True
try:
response = completion(
model="gpt-3.5-turbo", messages=messages, organization="org-ikDc4ex8NB"
)
pytest.fail("Request should have failed - This organization does not exist")
except Exception as e:
assert "No such organization: org-ikDc4ex8NB" in str(e)
except Exception as e:
print(e)
pytest.fail(f"Error occurred: {e}")
def test_completion_text_openai():
try:
# litellm.set_verbose = True

View file

@ -302,6 +302,25 @@ def test_bedrock_embedding_cohere():
# test_bedrock_embedding_cohere()
def test_demo_tokens_as_input_to_embeddings_fails_for_titan():
litellm.set_verbose = True
with pytest.raises(
litellm.BadRequestError,
match="BedrockException - Bedrock Embedding API input must be type str | List[str]",
):
litellm.embedding(model="amazon.titan-embed-text-v1", input=[[1]])
with pytest.raises(
litellm.BadRequestError,
match="BedrockException - Bedrock Embedding API input must be type str | List[str]",
):
litellm.embedding(
model="amazon.titan-embed-text-v1",
input=[1],
)
# comment out hf tests - since hf endpoints are unstable
def test_hf_embedding():
try:

View file

@ -472,3 +472,32 @@ def test_call_with_key_over_budget_stream(custom_db_client):
error_detail = e.message
assert "Authentication Error, ExceededTokenBudget:" in error_detail
print(vars(e))
def test_dynamo_db_migration(custom_db_client):
# Tests the temporary patch we have in place
setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "user_custom_auth", None)
try:
async def test():
bearer_token = (
"Bearer " + "sk-elJDL2pOEjcAuC7zD4psAg"
) # this works with ishaan's db, it's a never expiring key
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
async def return_body():
return b'{"model": "azure-models"}'
request.body = return_body
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)
asyncio.run(test())
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")

View file

@ -1188,3 +1188,27 @@ async def test_key_name_set(prisma_client):
except Exception as e:
print("Got Exception", e)
pytest.fail(f"Got exception {e}")
@pytest.mark.asyncio()
async def test_default_key_params(prisma_client):
"""
- create key
- get key info
- assert the default max_budget from litellm.default_key_generate_params is applied
"""
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True})
litellm.default_key_generate_params = {"max_budget": 0.000122}
await litellm.proxy.proxy_server.prisma_client.connect()
try:
request = GenerateKeyRequest()
key = await generate_key_fn(request)
generated_key = key.key
result = await info_key_fn(key=generated_key)
print("result from info_key_fn", result)
assert result["info"]["max_budget"] == 0.000122
except Exception as e:
print("Got Exception", e)
pytest.fail(f"Got exception {e}")

View file

@ -456,6 +456,7 @@ async def test_streaming_router_call():
@pytest.mark.asyncio
async def test_streaming_router_tpm_limit():
litellm.set_verbose = True
model_list = [
{
"model_name": "azure-model",
@ -520,7 +521,7 @@ async def test_streaming_router_tpm_limit():
)
async for chunk in response:
continue
await asyncio.sleep(1) # success is done in a separate thread
await asyncio.sleep(5) # success is done in a separate thread
try:
await parallel_request_handler.async_pre_call_hook(

View file

@ -387,3 +387,56 @@ def test_router_init_gpt_4_vision_enhancements():
print("passed")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_openai_with_organization():
try:
print("Testing OpenAI with organization")
model_list = [
{
"model_name": "openai-bad-org",
"litellm_params": {
"model": "gpt-3.5-turbo",
"organization": "org-ikDc4ex8NB",
},
},
{
"model_name": "openai-good-org",
"litellm_params": {"model": "gpt-3.5-turbo"},
},
]
router = Router(model_list=model_list)
print(router.model_list)
print(router.model_list[0])
openai_client = router._get_client(
deployment=router.model_list[0],
kwargs={"input": ["hello"], "model": "openai-bad-org"},
)
print(vars(openai_client))
assert openai_client.organization == "org-ikDc4ex8NB"
# bad org raises error
try:
response = router.completion(
model="openai-bad-org",
messages=[{"role": "user", "content": "this is a test"}],
)
pytest.fail("Request should have failed - This organization does not exist")
except Exception as e:
print("Got exception: " + str(e))
assert "No such organization: org-ikDc4ex8NB" in str(e)
# good org works
response = router.completion(
model="openai-good-org",
messages=[{"role": "user", "content": "this is a test"}],
max_tokens=5,
)
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@ -2929,32 +2929,10 @@ def cost_per_token(
model_with_provider_and_region in model_cost_ref
): # use region based pricing, if it's available
model_with_provider = model_with_provider_and_region
if model_with_provider in model_cost_ref:
model = model_with_provider
# see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
print_verbose(f"Looking up model={model} in model_cost_map")
if model_with_provider in model_cost_ref:
print_verbose(
f"Success: model={model_with_provider} in model_cost_map - {model_cost_ref[model_with_provider]}"
)
print_verbose(
f"applying cost={model_cost_ref[model_with_provider].get('input_cost_per_token', None)} for prompt_tokens={prompt_tokens}"
)
prompt_tokens_cost_usd_dollar = (
model_cost_ref[model_with_provider]["input_cost_per_token"] * prompt_tokens
)
print_verbose(
f"calculated prompt_tokens_cost_usd_dollar: {prompt_tokens_cost_usd_dollar}"
)
print_verbose(
f"applying cost={model_cost_ref[model_with_provider].get('output_cost_per_token', None)} for completion_tokens={completion_tokens}"
)
completion_tokens_cost_usd_dollar = (
model_cost_ref[model_with_provider]["output_cost_per_token"]
* completion_tokens
)
print_verbose(
f"calculated completion_tokens_cost_usd_dollar: {completion_tokens_cost_usd_dollar}"
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
if model in model_cost_ref:
print_verbose(f"Success: model={model} in model_cost_map")
print_verbose(
@ -7509,7 +7487,10 @@ class CustomStreamWrapper:
logprobs = None
original_chunk = None # this is used for function/tool calling
if len(str_line.choices) > 0:
if str_line.choices[0].delta.content is not None:
if (
str_line.choices[0].delta is not None
and str_line.choices[0].delta.content is not None
):
text = str_line.choices[0].delta.content
else: # function/tool calling chunk - when content is None. in this case we just return the original chunk from openai
original_chunk = str_line

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "1.20.2"
version = "1.20.6"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"
@ -63,7 +63,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"
[tool.commitizen]
version = "1.20.2"
version = "1.20.6"
version_files = [
"pyproject.toml:^version"
]

View file

@ -70,10 +70,11 @@ async def test_health_readiness():
url = "http://0.0.0.0:4000/health/readiness"
async with session.get(url) as response:
status = response.status
response_text = await response.text()
response_json = await response.json()
print(response_text)
print()
print(response_json)
assert "litellm_version" in response_json
assert "status" in response_json
if status != 200:
raise Exception(f"Request did not return a 200 status code: {status}")

View file

@ -1,20 +1,20 @@
# Use official Python base image
FROM python:3.9.12
# Use an official Node.js image as the base image
FROM node:18-alpine
EXPOSE 8501
# Set the working directory in the container
# Set the working directory inside the container
WORKDIR /app
# Copy the requirements.txt file to the container
COPY requirements.txt .
# Copy package.json and package-lock.json to the working directory
COPY ./litellm-dashboard/package*.json ./
# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt
# Install dependencies
RUN npm install
# Copy the entire project directory to the container
COPY admin.py .
# Copy the rest of the application code to the working directory
COPY ./litellm-dashboard .
# Set the entrypoint command to run admin.py with Streamlit
ENTRYPOINT [ "streamlit", "run"]
CMD ["admin.py"]
# Expose the port that the Next.js app will run on
EXPOSE 3000
# Start the Next.js app
CMD ["npm", "run", "dev"]

View file

@ -1,11 +1,14 @@
"use client";
import React, { useState, useEffect, useRef } from "react";
import { Button, TextInput, Grid, Col } from "@tremor/react";
import { message } from "antd";
import { Card, Metric, Text } from "@tremor/react";
import { Button as Button2, Modal, Form, Input, InputNumber, Select, message } from "antd";
import { keyCreateCall } from "./networking";
// Define the props type
const { Option } = Select;
interface CreateKeyProps {
userID: string;
accessToken: string;
@ -14,8 +17,6 @@ interface CreateKeyProps {
setData: React.Dispatch<React.SetStateAction<any[] | null>>;
}
import { Modal, Button as Button2 } from "antd";
const CreateKey: React.FC<CreateKeyProps> = ({
userID,
accessToken,
@ -23,49 +24,106 @@ const CreateKey: React.FC<CreateKeyProps> = ({
data,
setData,
}) => {
const [form] = Form.useForm();
const [isModalVisible, setIsModalVisible] = useState(false);
const [apiKey, setApiKey] = useState(null);
const handleOk = () => {
// Handle the OK action
console.log("OK Clicked");
setIsModalVisible(false);
form.resetFields();
};
const handleCancel = () => {
// Handle the cancel action or closing the modal
console.log("Modal closed");
setIsModalVisible(false);
setApiKey(null);
form.resetFields();
};
const handleCreate = async () => {
if (data == null) {
return;
}
const handleCreate = async (formValues: Record<string, any>) => {
try {
message.info("Making API Call");
// Check if "models" exists and is not an empty string
if (formValues.models && formValues.models.trim() !== '') {
// Format the "models" field as an array
formValues.models = formValues.models.split(',').map((model: string) => model.trim());
} else {
// If "models" is undefined or an empty string, set it to an empty array
formValues.models = [];
}
setIsModalVisible(true);
const response = await keyCreateCall(proxyBaseUrl, accessToken, userID);
// Successfully completed the deletion. Update the state to trigger a rerender.
setData([...data, response]);
const response = await keyCreateCall(proxyBaseUrl, accessToken, userID, formValues);
setData((prevData) => (prevData ? [...prevData, response] : [response])); // Check if prevData is null
setApiKey(response["key"]);
message.success("API Key Created");
form.resetFields();
} catch (error) {
console.error("Error deleting the key:", error);
// Handle any error situations, such as displaying an error message to the user.
console.error("Error creating the key:", error);
}
};
return (
<div>
<Button className="mx-auto" onClick={handleCreate}>
<Button className="mx-auto" onClick={() => setIsModalVisible(true)}>
+ Create New Key
</Button>
<Modal
title="Save your key"
open={isModalVisible}
title="Create Key"
visible={isModalVisible}
width={800}
footer={null}
onOk={handleOk}
onCancel={handleCancel}
>
<Form form={form} onFinish={handleCreate} labelCol={{ span: 6 }} wrapperCol={{ span: 16 }} labelAlign="left">
<Form.Item
label="Key Name"
name="key_alias"
>
<Input />
</Form.Item>
<Form.Item
label="Models (Comma Separated). Eg: gpt-3.5-turbo,gpt-4"
name="models"
>
<Input placeholder="gpt-4,gpt-3.5-turbo" />
</Form.Item>
<Form.Item
label="Max Budget (USD)"
name="max_budget"
>
<InputNumber step={0.01} precision={2} width={200}/>
</Form.Item>
<Form.Item
label="Duration (eg: 30s, 30h, 30d)"
name="duration"
>
<Input />
</Form.Item>
<Form.Item
label="Metadata"
name="metadata"
>
<Input.TextArea rows={4} placeholder="Enter metadata as JSON" />
</Form.Item>
<div style={{ textAlign: 'right', marginTop: '10px' }}>
<Button2 htmlType="submit">
Create Key
</Button2>
</div>
</Form>
</Modal>
{apiKey && (
<Modal
title="Save your key"
visible={isModalVisible}
onOk={handleOk}
onCancel={handleCancel}
footer={null}
>
<Grid numItems={1} className="gap-2 w-full">
<Col numColSpan={1}>
@ -85,6 +143,7 @@ const CreateKey: React.FC<CreateKeyProps> = ({
</Col>
</Grid>
</Modal>
)}
</div>
);
};

View file

@ -3,11 +3,13 @@
*/
export const keyCreateCall = async (
proxyBaseUrl: String,
accessToken: String,
userID: String
proxyBaseUrl: string,
accessToken: string,
userID: string,
formValues: Record<string, any> // Assuming formValues is an object
) => {
try {
console.log("Form Values in keyCreateCall:", formValues); // Log the form values before making the API call
const response = await fetch(`${proxyBaseUrl}/key/generate`, {
method: "POST",
headers: {
@ -15,18 +17,19 @@ export const keyCreateCall = async (
"Content-Type": "application/json",
},
body: JSON.stringify({
team_id: "core-infra-4",
max_budget: 10,
user_id: userID,
...formValues, // Include formValues in the request body
}),
});
if (!response.ok) {
const errorData = await response.json();
console.error("Error response from the server:", errorData);
throw new Error("Network response was not ok");
}
const data = await response.json();
console.log(data);
console.log("API Response:", data);
return data;
// Handle success - you might want to update some state or UI based on the created key
} catch (error) {
@ -35,6 +38,7 @@ export const keyCreateCall = async (
}
};
export const keyDeleteCall = async (
proxyBaseUrl: String,
accessToken: String,

View file

@ -43,11 +43,10 @@ const UserDashboard = () => {
);
}
else if (userID == null || accessToken == null) {
// redirect to page: ProxyBaseUrl/google-login/key/generate
const baseUrl = proxyBaseUrl.endsWith('/') ? proxyBaseUrl : proxyBaseUrl + '/';
// Now you can construct the full URL
const url = `${baseUrl}google-login/key/generate`;
const url = `${baseUrl}sso/key/generate`;
window.location.href = url;

View file

@ -58,7 +58,8 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
<TableHead>
<TableRow>
<TableHeaderCell>Secret Key</TableHeaderCell>
<TableHeaderCell>Spend</TableHeaderCell>
<TableHeaderCell>Spend (USD)</TableHeaderCell>
<TableHeaderCell>Key Budget (USD)</TableHeaderCell>
<TableHeaderCell>Expires</TableHeaderCell>
</TableRow>
</TableHead>
@ -68,11 +69,24 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
return (
<TableRow key={item.token}>
<TableCell>
{item.key_name != null ? (
<Text>{item.key_name}</Text>
) : (
<Text>{item.token}</Text>
)
}
</TableCell>
<TableCell>
<Text>{item.spend}</Text>
</TableCell>
<TableCell>
{item.max_budget != null ? (
<Text>{item.max_budget}</Text>
) : (
<Text>Unlimited Budget</Text>
)
}
</TableCell>
<TableCell>
{item.expires != null ? (
<Text>{item.expires}</Text>