Merge branch 'main' into litellm_ssl_caching_fix

commit a9dc93e860
Krish Dholakia, 2024-04-19 17:20:27 -07:00, committed by GitHub
25 changed files with 288 additions and 63 deletions

View file

@@ -48,6 +48,8 @@ We support ALL Groq models, just set `groq/` as a prefix when sending completion
 | Model Name          | Function Call                                            |
 |---------------------|----------------------------------------------------------|
+| llama3-8b-8192      | `completion(model="groq/llama3-8b-8192", messages)`      |
+| llama3-70b-8192     | `completion(model="groq/llama3-70b-8192", messages)`     |
 | llama2-70b-4096     | `completion(model="groq/llama2-70b-4096", messages)`     |
 | mixtral-8x7b-32768  | `completion(model="groq/mixtral-8x7b-32768", messages)`  |
 | gemma-7b-it         | `completion(model="groq/gemma-7b-it", messages)`         |
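For context, a minimal sketch of calling one of the newly documented Groq models, using the `completion(...)` signature shown in the table above. The `GROQ_API_KEY` environment variable and the printed attribute access are assumptions based on litellm's usual OpenAI-style response object, not part of this diff:

```python
import os
from litellm import completion

# Assumed: Groq credentials are supplied via this environment variable.
os.environ["GROQ_API_KEY"] = "gsk-..."  # placeholder key for illustration

response = completion(
    model="groq/llama3-8b-8192",  # one of the models added in this docs table
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.choices[0].message.content)
```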

View file

@@ -16,11 +16,24 @@ dotenv.load_dotenv()
 if set_verbose == True:
     _turn_on_debug()
 #############################################
+### Callbacks /Logging / Success / Failure Handlers ###
 input_callback: List[Union[str, Callable]] = []
 success_callback: List[Union[str, Callable]] = []
 failure_callback: List[Union[str, Callable]] = []
 service_callback: List[Union[str, Callable]] = []
 callbacks: List[Callable] = []
+_langfuse_default_tags: Optional[
+    List[
+        Literal[
+            "user_api_key_alias",
+            "user_api_key_user_id",
+            "user_api_key_user_email",
+            "user_api_key_team_alias",
+            "semantic-similarity",
+            "proxy_base_url",
+        ]
+    ]
+] = None
 _async_input_callback: List[Callable] = (
     []
 )  # internal variable - async custom callbacks are routed here.
@@ -32,6 +45,8 @@ _async_failure_callback: List[Callable] = (
 )  # internal variable - async custom callbacks are routed here.
 pre_call_rules: List[Callable] = []
 post_call_rules: List[Callable] = []
+## end of callbacks #############
 email: Optional[str] = (
     None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
 )
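As a usage sketch (not part of this diff), the new module-level `_langfuse_default_tags` setting can be configured directly in Python; `success_callback` is an existing litellm setting, and the tag names come from the `Literal` list above:

```python
import litellm

# Route success logs to Langfuse and restrict which metadata keys become Langfuse tags.
litellm.success_callback = ["langfuse"]
litellm._langfuse_default_tags = [
    "user_api_key_alias",
    "user_api_key_team_alias",
    "proxy_base_url",
]
```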

View file

@@ -1281,7 +1281,7 @@ class DualCache(BaseCache):
                     self.in_memory_cache.set_cache(key, redis_result[key], **kwargs)

            for key, value in redis_result.items():
-                result[sublist_keys.index(key)] = value
+                result[keys.index(key)] = value

            print_verbose(f"async batch get cache: cache result: {result}")
            return result
@@ -1331,7 +1331,6 @@
                 keys, **kwargs
            )
-            print_verbose(f"in_memory_result: {in_memory_result}")
            if in_memory_result is not None:
                result = in_memory_result
            if None in result and self.redis_cache is not None and local_only == False:
@@ -1355,9 +1354,9 @@
                            key, redis_result[key], **kwargs
                        )
                for key, value in redis_result.items():
-                    result[sublist_keys.index(key)] = value
+                    index = keys.index(key)
+                    result[index] = value
-                print_verbose(f"async batch get cache: cache result: {result}")
                return result
        except Exception as e:
            traceback.print_exc()
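The bug being fixed here is the positional mapping of Redis hits back into the batch result: the index must be looked up against the full `keys` list, not a `sublist_keys` subset. A standalone sketch of the corrected merge logic, with hypothetical data (not litellm code):

```python
# Hypothetical illustration of the corrected batch-merge step.
keys = ["a", "b", "c"]                    # original batch order requested by the caller
result = ["in-memory-A", None, None]      # partial in-memory hits, None = miss
redis_result = {"c": "redis-C"}           # values fetched from Redis for the misses

for key, value in redis_result.items():
    index = keys.index(key)               # position in the *original* key list
    result[index] = value

assert result == ["in-memory-A", None, "redis-C"]
```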

View file

@@ -280,13 +280,13 @@ class LangFuseLogger:
            clean_metadata = {}
            if isinstance(metadata, dict):
                for key, value in metadata.items():
-                    # generate langfuse tags
-                    if key in [
-                        "user_api_key",
-                        "user_api_key_user_id",
-                        "user_api_key_team_id",
-                        "semantic-similarity",
-                    ]:
+                    # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy
+                    if (
+                        litellm._langfuse_default_tags is not None
+                        and isinstance(litellm._langfuse_default_tags, list)
+                        and key in litellm._langfuse_default_tags
+                    ):
                        tags.append(f"{key}:{value}")

                    # clean litellm metadata before logging
@@ -300,6 +300,15 @@
                    else:
                        clean_metadata[key] = value

+            if (
+                litellm._langfuse_default_tags is not None
+                and isinstance(litellm._langfuse_default_tags, list)
+                and "proxy_base_url" in litellm._langfuse_default_tags
+            ):
+                proxy_base_url = os.environ.get("PROXY_BASE_URL", None)
+                if proxy_base_url is not None:
+                    tags.append(f"proxy_base_url:{proxy_base_url}")
+
            api_base = litellm_params.get("api_base", None)
            if api_base:
                clean_metadata["api_base"] = api_base

View file

@@ -151,7 +151,6 @@ class PrometheusServicesLogger:
        if self.mock_testing:
            self.mock_testing_success_calls += 1
-        print(f"LOGS SUCCESSFUL CALL TO PROMETHEUS - payload={payload}")
        if payload.service.value in self.payload_to_prometheus_map:
            prom_objects = self.payload_to_prometheus_map[payload.service.value]
            for obj in prom_objects:

View file

@@ -258,8 +258,9 @@ class AnthropicChatCompletion(BaseLLM):
            self.async_handler = AsyncHTTPHandler(
                timeout=httpx.Timeout(timeout=600.0, connect=5.0)
            )
+            data["stream"] = True
            response = await self.async_handler.post(
-                api_base, headers=headers, data=json.dumps(data)
+                api_base, headers=headers, data=json.dumps(data), stream=True
            )

            if response.status_code != 200:

View file

@@ -41,13 +41,12 @@ class AsyncHTTPHandler:
        data: Optional[Union[dict, str]] = None,  # type: ignore
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
+        stream: bool = False,
    ):
-        response = await self.client.post(
-            url,
-            data=data,  # type: ignore
-            params=params,
-            headers=headers,
+        req = self.client.build_request(
+            "POST", url, data=data, params=params, headers=headers  # type: ignore
        )
+        response = await self.client.send(req, stream=stream)
        return response

    def __del__(self) -> None:
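The switch from `client.post(...)` to `build_request(...)` plus `client.send(..., stream=True)` is what lets the caller consume the response body incrementally instead of httpx buffering it in memory. A minimal standalone httpx sketch of the same pattern, independent of litellm (the URL and payload are made up):

```python
import asyncio
import httpx


async def stream_post(url: str, payload: str) -> None:
    async with httpx.AsyncClient() as client:
        # Build the request explicitly so we can opt into a streamed send.
        req = client.build_request("POST", url, content=payload)
        response = await client.send(req, stream=True)
        try:
            # Iterate over the body as it arrives instead of waiting for the full response.
            async for chunk in response.aiter_bytes():
                print(chunk)
        finally:
            await response.aclose()  # release the connection when done


asyncio.run(stream_post("https://example.com/stream", '{"stream": true}'))
```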

View file

@@ -735,6 +735,26 @@
        "mode": "chat",
        "supports_function_calling": true
    },
+    "groq/llama3-8b-8192": {
+        "max_tokens": 8192,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000010,
+        "output_cost_per_token": 0.00000010,
+        "litellm_provider": "groq",
+        "mode": "chat",
+        "supports_function_calling": true
+    },
+    "groq/llama3-70b-8192": {
+        "max_tokens": 8192,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000064,
+        "output_cost_per_token": 0.00000080,
+        "litellm_provider": "groq",
+        "mode": "chat",
+        "supports_function_calling": true
+    },
    "groq/mixtral-8x7b-32768": {
        "max_tokens": 32768,
        "max_input_tokens": 32768,

View file

@@ -7,6 +7,14 @@ model_list:
      # api_base: http://0.0.0.0:8080
      stream_timeout: 0.001
      rpm: 10
+  - model_name: fake-openai-endpoint
+    litellm_params:
+      model: openai/my-fake-model-2
+      api_key: my-fake-key
+      api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
+      # api_base: http://0.0.0.0:8080
+      stream_timeout: 0.001
+      rpm: 10
  - litellm_params:
      model: azure/chatgpt-v-2
      api_base: os.environ/AZURE_API_BASE

View file

@@ -792,6 +792,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
    """

    team_spend: Optional[float] = None
+    team_alias: Optional[str] = None
    team_tpm_limit: Optional[int] = None
    team_rpm_limit: Optional[int] = None
    team_max_budget: Optional[float] = None

View file

@@ -14,3 +14,7 @@ general_settings:
  store_model_in_db: true
  master_key: sk-1234
  alerting: ["slack"]
+
+litellm_settings:
+  success_callback: ["langfuse"]
+  _langfuse_default_tags: ["user_api_key_alias", "user_api_key_user_id", "user_api_key_user_email", "user_api_key_team_alias", "semantic-similarity", "proxy_base_url"]

View file

@@ -3361,6 +3361,9 @@ async def completion(
        data["metadata"]["user_api_key_team_id"] = getattr(
            user_api_key_dict, "team_id", None
        )
+        data["metadata"]["user_api_key_team_alias"] = getattr(
+            user_api_key_dict, "team_alias", None
+        )
        _headers = dict(request.headers)
        _headers.pop(
            "authorization", None
@@ -3562,6 +3565,9 @@ async def chat_completion(
        data["metadata"]["user_api_key_team_id"] = getattr(
            user_api_key_dict, "team_id", None
        )
+        data["metadata"]["user_api_key_team_alias"] = getattr(
+            user_api_key_dict, "team_alias", None
+        )
        data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
        _headers = dict(request.headers)
        _headers.pop(
@@ -3793,6 +3799,9 @@ async def embeddings(
        data["metadata"]["user_api_key_team_id"] = getattr(
            user_api_key_dict, "team_id", None
        )
+        data["metadata"]["user_api_key_team_alias"] = getattr(
+            user_api_key_dict, "team_alias", None
+        )
        data["metadata"]["endpoint"] = str(request.url)

        ### TEAM-SPECIFIC PARAMS ###
@@ -3971,6 +3980,9 @@ async def image_generation(
        data["metadata"]["user_api_key_team_id"] = getattr(
            user_api_key_dict, "team_id", None
        )
+        data["metadata"]["user_api_key_team_alias"] = getattr(
+            user_api_key_dict, "team_alias", None
+        )
        data["metadata"]["endpoint"] = str(request.url)

        ### TEAM-SPECIFIC PARAMS ###
@@ -4127,6 +4139,9 @@ async def audio_transcriptions(
        data["metadata"]["user_api_key_team_id"] = getattr(
            user_api_key_dict, "team_id", None
        )
+        data["metadata"]["user_api_key_team_alias"] = getattr(
+            user_api_key_dict, "team_alias", None
+        )
        data["metadata"]["endpoint"] = str(request.url)
        data["metadata"]["file_name"] = file.filename
@@ -4302,6 +4317,9 @@ async def moderations(
        data["metadata"]["user_api_key_team_id"] = getattr(
            user_api_key_dict, "team_id", None
        )
+        data["metadata"]["user_api_key_team_alias"] = getattr(
+            user_api_key_dict, "team_alias", None
+        )
        data["metadata"]["endpoint"] = str(request.url)

        ### TEAM-SPECIFIC PARAMS ###
View file

@@ -1186,6 +1186,7 @@ class PrismaClient:
                    t.rpm_limit AS team_rpm_limit,
                    t.models AS team_models,
                    t.blocked AS team_blocked,
+                    t.team_alias AS team_alias,
                    m.aliases as team_model_aliases
                FROM "LiteLLM_VerificationToken" AS v
                LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id

View file

@@ -187,6 +187,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
                    request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                ),
            )
+            return deployment
        except Exception as e:
            if isinstance(e, litellm.RateLimitError):

View file

@@ -33,6 +33,51 @@ def generate_random_word(length=4):
 messages = [{"role": "user", "content": "who is ishaan 5222"}]


+@pytest.mark.asyncio
+async def test_dual_cache_async_batch_get_cache():
+    """
+    Unit testing for Dual Cache async_batch_get_cache()
+    - 2 item query
+    - in_memory result has a partial hit (1/2)
+    - hit redis for the other -> expect to return None
+    - expect result = [in_memory_result, None]
+    """
+    from litellm.caching import DualCache, InMemoryCache, RedisCache
+
+    in_memory_cache = InMemoryCache()
+    redis_cache = RedisCache()  # get credentials from environment
+    dual_cache = DualCache(in_memory_cache=in_memory_cache, redis_cache=redis_cache)
+
+    in_memory_cache.set_cache(key="test_value", value="hello world")
+
+    result = await dual_cache.async_batch_get_cache(keys=["test_value", "test_value_2"])
+
+    assert result[0] == "hello world"
+    assert result[1] == None
+
+
+def test_dual_cache_batch_get_cache():
+    """
+    Unit testing for Dual Cache batch_get_cache()
+    - 2 item query
+    - in_memory result has a partial hit (1/2)
+    - hit redis for the other -> expect to return None
+    - expect result = [in_memory_result, None]
+    """
+    from litellm.caching import DualCache, InMemoryCache, RedisCache
+
+    in_memory_cache = InMemoryCache()
+    redis_cache = RedisCache()  # get credentials from environment
+    dual_cache = DualCache(in_memory_cache=in_memory_cache, redis_cache=redis_cache)
+
+    in_memory_cache.set_cache(key="test_value", value="hello world")
+
+    result = dual_cache.batch_get_cache(keys=["test_value", "test_value_2"])
+
+    assert result[0] == "hello world"
+    assert result[1] == None
+
+
 # @pytest.mark.skip(reason="")
 def test_caching_dynamic_args():  # test in memory cache
     try:

View file

@@ -221,6 +221,9 @@ def test_parallel_function_call_stream():
 # test_parallel_function_call_stream()


+@pytest.mark.skip(
+    reason="Flaky test. Groq function calling is not reliable for ci/cd testing."
+)
 def test_groq_parallel_function_call():
     litellm.set_verbose = True
     try:

View file

@@ -23,6 +23,10 @@ from litellm.caching import DualCache

 ### UNIT TESTS FOR TPM/RPM ROUTING ###

+"""
+- Given 2 deployments, make sure it's shuffling deployments correctly.
+"""
+

 def test_tpm_rpm_updated():
     test_cache = DualCache()

View file

@@ -735,6 +735,26 @@
        "mode": "chat",
        "supports_function_calling": true
    },
+    "groq/llama3-8b-8192": {
+        "max_tokens": 8192,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000010,
+        "output_cost_per_token": 0.00000010,
+        "litellm_provider": "groq",
+        "mode": "chat",
+        "supports_function_calling": true
+    },
+    "groq/llama3-70b-8192": {
+        "max_tokens": 8192,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000064,
+        "output_cost_per_token": 0.00000080,
+        "litellm_provider": "groq",
+        "mode": "chat",
+        "supports_function_calling": true
+    },
    "groq/mixtral-8x7b-32768": {
        "max_tokens": 32768,
        "max_input_tokens": 32768,

View file

@@ -55,6 +55,20 @@ model_list:
      api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
      stream_timeout: 0.001
      rpm: 1
+  - model_name: fake-openai-endpoint-3
+    litellm_params:
+      model: openai/my-fake-model
+      api_key: my-fake-key
+      api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
+      stream_timeout: 0.001
+      rpm: 10
+  - model_name: fake-openai-endpoint-3
+    litellm_params:
+      model: openai/my-fake-model-2
+      api_key: my-fake-key
+      api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
+      stream_timeout: 0.001
+      rpm: 10
  - model_name: "*"
    litellm_params:
      model: openai/*

View file

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.35.15"
+version = "1.35.16"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -80,7 +80,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"

 [tool.commitizen]
-version = "1.35.15"
+version = "1.35.16"
 version_files = [
     "pyproject.toml:^version"
 ]

View file

@@ -102,6 +102,47 @@ async def chat_completion(session, key, model="gpt-4"):
        return await response.json()


+async def chat_completion_with_headers(session, key, model="gpt-4"):
+    url = "http://0.0.0.0:4000/chat/completions"
+    headers = {
+        "Authorization": f"Bearer {key}",
+        "Content-Type": "application/json",
+    }
+    data = {
+        "model": model,
+        "messages": [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Hello!"},
+        ],
+    }
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+
+        print(response_text)
+        print()
+
+        if status != 200:
+            raise Exception(f"Request did not return a 200 status code: {status}")
+
+        response_header_check(
+            response
+        )  # calling the function to check response headers
+
+        raw_headers = response.raw_headers
+        raw_headers_json = {}
+        for (
+            item
+        ) in (
+            response.raw_headers
+        ):  # ((b'date', b'Fri, 19 Apr 2024 21:17:29 GMT'), (), )
+            raw_headers_json[item[0].decode("utf-8")] = item[1].decode("utf-8")
+
+        return raw_headers_json
+
+
 async def completion(session, key):
     url = "http://0.0.0.0:4000/completions"
     headers = {
@@ -222,6 +263,36 @@ async def test_chat_completion_ratelimit():
        pass


+@pytest.mark.asyncio
+async def test_chat_completion_different_deployments():
+    """
+    - call model group with 2 deployments
+    - make 5 calls
+    - expect 2 unique deployments
+    """
+    async with aiohttp.ClientSession() as session:
+        # key_gen = await generate_key(session=session)
+        key = "sk-1234"
+        results = []
+        for _ in range(5):
+            results.append(
+                await chat_completion_with_headers(
+                    session=session, key=key, model="fake-openai-endpoint-3"
+                )
+            )
+        try:
+            print(f"results: {results}")
+            init_model_id = results[0]["x-litellm-model-id"]
+            deployments_shuffled = False
+            for result in results[1:]:
+                if init_model_id != result["x-litellm-model-id"]:
+                    deployments_shuffled = True
+            if deployments_shuffled == False:
+                pytest.fail("Expected at least 1 shuffled call")
+        except Exception as e:
+            pass
+
+
 @pytest.mark.asyncio
 async def test_chat_completion_old_key():
     """

View file

@@ -145,6 +145,7 @@ const CreateKeyPage = () => {
            userRole={userRole}
            token={token}
            keys={keys}
+            teams={teams}
            accessToken={accessToken}
            setKeys={setKeys}
          />

View file

@@ -1,15 +1,16 @@
 import React, { useState, useEffect } from "react";
 import { Button, Modal, Form, Input, message, Select, InputNumber } from "antd";
-import { Button as Button2 } from "@tremor/react";
+import { Button as Button2, Text } from "@tremor/react";
 import { userCreateCall, modelAvailableCall } from "./networking";
 const { Option } = Select;

 interface CreateuserProps {
   userID: string;
   accessToken: string;
+  teams: any[] | null;
 }

-const Createuser: React.FC<CreateuserProps> = ({ userID, accessToken }) => {
+const Createuser: React.FC<CreateuserProps> = ({ userID, accessToken, teams }) => {
   const [form] = Form.useForm();
   const [isModalVisible, setIsModalVisible] = useState(false);
   const [apiuser, setApiuser] = useState<string | null>(null);
@@ -59,7 +60,7 @@ const Createuser: React.FC<CreateuserProps> = ({ userID, accessToken }) => {
       message.info("Making API Call");
       setIsModalVisible(true);
       console.log("formValues in create user:", formValues);
-      const response = await userCreateCall(accessToken, userID, formValues);
+      const response = await userCreateCall(accessToken, null, formValues);
       console.log("user create Response:", response);
       setApiuser(response["key"]);
       message.success("API user Created");
@@ -73,16 +74,18 @@ const Createuser: React.FC<CreateuserProps> = ({ userID, accessToken }) => {
   return (
     <div>
       <Button2 className="mx-auto" onClick={() => setIsModalVisible(true)}>
-        + Create New User
+        + Invite User
       </Button2>
       <Modal
-        title="Create User"
+        title="Invite User"
         visible={isModalVisible}
         width={800}
         footer={null}
         onOk={handleOk}
         onCancel={handleCancel}
       >
+        <Text className="mb-1">Invite a user to login to the Admin UI and create Keys</Text>
+        <Text className="mb-6"><b>Note: SSO Setup Required for this</b></Text>
         <Form
           form={form}
           onFinish={handleCreate}
@@ -90,38 +93,27 @@ const Createuser: React.FC<CreateuserProps> = ({ userID, accessToken }) => {
           wrapperCol={{ span: 16 }}
           labelAlign="left"
         >
-          <Form.Item label="User ID" name="user_id">
-            <Input placeholder="Enter User ID" />
+          <Form.Item label="User Email" name="user_email">
+            <Input placeholder="Enter User Email" />
           </Form.Item>
           <Form.Item label="Team ID" name="team_id">
-            <Input placeholder="ai_team" />
-          </Form.Item>
-          <Form.Item label="Models" name="models">
-            <Select
-              mode="multiple"
-              placeholder="Select models"
+            <Select
+              placeholder="Select Team ID"
               style={{ width: "100%" }}
             >
-              {userModels.map((model) => (
-                <Option key={model} value={model}>
-                  {model}
+              {teams ? (
+                teams.map((team: any) => (
+                  <Option key={team.team_id} value={team.team_id}>
+                    {team.team_alias}
+                  </Option>
+                ))
+              ) : (
+                <Option key="default" value={null}>
+                  Default Team
                 </Option>
-              ))}
+              )}
             </Select>
           </Form.Item>
-          <Form.Item label="Max Budget (USD)" name="max_budget">
-            <InputNumber step={0.01} precision={2} width={200} />
-          </Form.Item>
-          <Form.Item label="Tokens per minute Limit (TPM)" name="tpm_limit">
-            <InputNumber step={1} width={400} />
-          </Form.Item>
-          <Form.Item label="Requests per minute Limit (RPM)" name="rpm_limit">
-            <InputNumber step={1} width={400} />
-          </Form.Item>
-          <Form.Item label="Duration (eg: 30s, 30h, 30d)" name="duration">
-            <Input />
-          </Form.Item>
           <Form.Item label="Metadata" name="metadata">
             <Input.TextArea rows={4} placeholder="Enter metadata as JSON" />
           </Form.Item>
@@ -132,23 +124,19 @@ const Createuser: React.FC<CreateuserProps> = ({ userID, accessToken }) => {
       </Modal>
       {apiuser && (
         <Modal
-          title="Save Your User"
+          title="User Created Successfully"
           visible={isModalVisible}
           onOk={handleOk}
           onCancel={handleCancel}
           footer={null}
         >
           <p>
-            Please save this secret user somewhere safe and accessible. For
-            security reasons, <b>you will not be able to view it again</b>{" "}
-            through your LiteLLM account. If you lose this secret user, you will
-            need to generate a new one.
-          </p>
-          <p>
-            {apiuser != null
-              ? `API user: ${apiuser}`
-              : "User being created, this might take 30s"}
+            User has been created to access your proxy. Please Ask them to Log In.
           </p>
+          <br></br>
+          <p><b>Note: This Feature is only supported through SSO on the Admin UI</b></p>
         </Modal>
       )}
     </div>

View file

@@ -158,7 +158,7 @@ export const keyCreateCall = async (

 export const userCreateCall = async (
   accessToken: string,
-  userID: string,
+  userID: string | null,
   formValues: Record<string, any> // Assuming formValues is an object
 ) => {
   try {

View file

@@ -36,6 +36,7 @@ interface ViewUserDashboardProps {
   keys: any[] | null;
   userRole: string | null;
   userID: string | null;
+  teams: any[] | null;
   setKeys: React.Dispatch<React.SetStateAction<Object[] | null>>;
 }
@@ -45,6 +46,7 @@ const ViewUserDashboard: React.FC<ViewUserDashboardProps> = ({
   keys,
   userRole,
   userID,
+  teams,
   setKeys,
 }) => {
   const [userData, setUserData] = useState<null | any[]>(null);
@@ -151,7 +153,7 @@ const ViewUserDashboard: React.FC<ViewUserDashboardProps> = ({
   return (
     <div style={{ width: "100%" }}>
       <Grid className="gap-2 p-2 h-[75vh] w-full mt-8">
-        <CreateUser userID={userID} accessToken={accessToken} />
+        <CreateUser userID={userID} accessToken={accessToken} teams={teams}/>
         <Card className="w-full mx-auto flex-auto overflow-y-auto max-h-[50vh] mb-4">
           <TabGroup>
             <TabList variant="line" defaultValue="1">