Merge branch 'main' into litellm_block_unblock_user_api

Commit aaf086c0a8 by Krish Dholakia, 2024-02-24 11:43:16 -08:00 (committed by GitHub)
17 changed files with 771 additions and 438 deletions

@@ -121,7 +121,7 @@ response = completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content
 # OpenAI Proxy - ([Docs](https://docs.litellm.ai/docs/simple_proxy))
-Track spend across multiple projects/people
+Set Budgets & Rate limits across multiple projects
 The proxy provides:
 1. [Hooks for auth](https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth)
@@ -163,7 +163,7 @@ print(response)
 UI on `/ui` on your proxy server
 ![ui_3](https://github.com/BerriAI/litellm/assets/29436595/47c97d5e-b9be-4839-b28c-43d7f4f10033)
-Track Spend, Set budgets and create virtual keys for the proxy
+Set budgets and rate limits across multiple projects
 `POST /key/generate`
 ### Request
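For context, a request to the `POST /key/generate` endpoint referenced above might look like the following sketch. The budget and rate-limit fields mirror the ones exposed by the admin UI later in this diff; the proxy URL and master key are placeholders, not values from this commit:

```python
import requests

# Hypothetical proxy URL and master key; substitute your own deployment values.
resp = requests.post(
    "http://0.0.0.0:4000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "models": ["gpt-3.5-turbo"],  # models this virtual key may call
        "duration": "30d",            # key lifetime
        "max_budget": 10.0,           # USD budget for the key
        "tpm_limit": 1000,            # tokens-per-minute limit
        "rpm_limit": 60,              # requests-per-minute limit
    },
)
print(resp.json()["key"])
```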

@@ -165,6 +165,7 @@ s3_callback_params: Optional[Dict] = None
 generic_logger_headers: Optional[Dict] = None
 default_key_generate_params: Optional[Dict] = None
 upperbound_key_generate_params: Optional[Dict] = None
+default_user_params: Optional[Dict] = None
 default_team_settings: Optional[List] = None
 max_user_budget: Optional[float] = None
 #### RELIABILITY ####
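The new `default_user_params` setting is consumed in `auth_callback` (see the proxy server hunk further down) to seed the models, user_id, and user_email for UI users that do not yet have a database record. A minimal sketch of how it might be set, assuming it is assigned before the proxy starts handling SSO callbacks:

```python
import litellm

# Hypothetical defaults; the keys mirror what auth_callback reads below.
litellm.default_user_params = {
    "models": ["gpt-3.5-turbo"],           # models granted to new UI users
    "user_email": "new-user@example.com",  # fallback email if SSO provides none
}
```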

@@ -235,6 +235,9 @@ class LangFuseLogger:
         supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3")
         supports_prompt = Version(langfuse.version.__version__) >= Version("2.7.3")
         supports_costs = Version(langfuse.version.__version__) >= Version("2.7.3")
+        supports_completion_start_time = Version(
+            langfuse.version.__version__
+        ) >= Version("2.7.3")
         print_verbose(f"Langfuse Layer Logging - logging to langfuse v2 ")
@@ -308,6 +311,11 @@ class LangFuseLogger:
             if output is not None and isinstance(output, str) and level == "ERROR":
                 generation_params["statusMessage"] = output
+            if supports_completion_start_time:
+                generation_params["completion_start_time"] = kwargs.get(
+                    "completion_start_time", None
+                )
             trace.generation(**generation_params)
         except Exception as e:
             print(f"Langfuse Layer Error - {traceback.format_exc()}")

@@ -575,6 +575,14 @@ class Huggingface(BaseLLM):
             response = await client.post(url=api_base, json=data, headers=headers)
             response_json = response.json()
             if response.status_code != 200:
+                if "error" in response_json:
+                    raise HuggingfaceError(
+                        status_code=response.status_code,
+                        message=response_json["error"],
+                        request=response.request,
+                        response=response,
+                    )
+                else:
                 raise HuggingfaceError(
                     status_code=response.status_code,
                     message=response.text,
@@ -595,6 +603,8 @@ class Huggingface(BaseLLM):
         except Exception as e:
             if isinstance(e, httpx.TimeoutException):
                 raise HuggingfaceError(status_code=500, message="Request Timeout Error")
+            elif isinstance(e, HuggingfaceError):
+                raise e
             elif response is not None and hasattr(response, "text"):
                 raise HuggingfaceError(
                     status_code=500,

@@ -730,6 +730,8 @@ async def user_api_key_auth(
                 "/user",
                 "/model/info",
                 "/v2/model/info",
+                "/models",
+                "/v1/models",
             ]
             # check if the current route startswith any of the allowed routes
             if (
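With `/models` and `/v1/models` added to the allowed routes above, a virtual-key holder can list the models available to them. A small sketch of what such a call might look like; the proxy base URL and key are placeholders:

```python
import requests

# Hypothetical proxy base URL and virtual key.
resp = requests.get(
    "http://0.0.0.0:4000/v1/models",
    headers={"Authorization": "Bearer sk-my-virtual-key"},
)
# The response follows the OpenAI list format: {"data": [{"id": "..."}, ...]},
# which is also what the dashboard's new modelAvailableCall helper consumes.
print([m["id"] for m in resp.json()["data"]])
```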
@@ -1758,6 +1760,7 @@ async def generate_key_helper_fn(
     allowed_cache_controls: Optional[list] = [],
     permissions: Optional[dict] = {},
     model_max_budget: Optional[dict] = {},
+    table_name: Optional[Literal["key", "user"]] = None,
 ):
     global prisma_client, custom_db_client, user_api_key_cache
@@ -1884,8 +1887,10 @@ async def generate_key_helper_fn(
                 table_name="user",
                 update_key_values=update_key_values,
             )
-            if user_id == litellm_proxy_budget_name:
-                # do not create a key for litellm_proxy_budget_name
+            if user_id == litellm_proxy_budget_name or (
+                table_name is not None and table_name == "user"
+            ):
+                # do not create a key for litellm_proxy_budget_name or if table name is set to just 'user'
                 # we only need to ensure this exists in the user table
                 # the LiteLLM_VerificationToken table will increase in size if we don't do this check
                 return key_data
@@ -2440,7 +2445,7 @@ async def completion(
         )
         traceback.print_exc()
         error_traceback = traceback.format_exc()
-        error_msg = f"{str(e)}\n\n{error_traceback}"
+        error_msg = f"{str(e)}"
         raise ProxyException(
             message=getattr(e, "message", error_msg),
             type=getattr(e, "type", "None"),
@@ -5548,27 +5553,50 @@ async def auth_callback(request: Request):
     user_id_models: List = []

     # User might not be already created on first generation of key
-    # But if it is, we want its models preferences
-    try:
-        if prisma_client is not None:
-            user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
-            if user_info is not None:
-                user_id_models = getattr(user_info, "models", [])
-    except Exception as e:
-        pass
-
-    response = await generate_key_helper_fn(
-        **{
-            "duration": "1hr",
-            "key_max_budget": 0.01,
-            "models": user_id_models,
-            "aliases": {},
-            "config": {},
-            "spend": 0,
-            "user_id": user_id,
-            "team_id": "litellm-dashboard",
-            "user_email": user_email,
-        }  # type: ignore
+    # But if it is, we want their models preferences
+    default_ui_key_values = {
+        "duration": "1hr",
+        "key_max_budget": 0.01,
+        "aliases": {},
+        "config": {},
+        "spend": 0,
+        "team_id": "litellm-dashboard",
+    }
+    user_defined_values = {
+        "models": user_id_models,
+        "user_id": user_id,
+        "user_email": user_email,
+    }
+    try:
+        if prisma_client is not None:
+            user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
+            verbose_proxy_logger.debug(
+                f"user_info: {user_info}; litellm.default_user_params: {litellm.default_user_params}"
+            )
+            if user_info is not None:
+                user_defined_values = {
+                    "models": getattr(user_info, "models", []),
+                    "user_id": getattr(user_info, "user_id", user_id),
+                    "user_email": getattr(user_info, "user_id", user_email),
+                }
+            elif litellm.default_user_params is not None and isinstance(
+                litellm.default_user_params, dict
+            ):
+                user_defined_values = {
+                    "models": litellm.default_user_params.get("models", user_id_models),
+                    "user_id": litellm.default_user_params.get("user_id", user_id),
+                    "user_email": litellm.default_user_params.get(
+                        "user_email", user_email
+                    ),
+                }
+    except Exception as e:
+        pass
+
+    verbose_proxy_logger.info(
+        f"user_defined_values for creating ui key: {user_defined_values}"
+    )
+    response = await generate_key_helper_fn(
+        **default_ui_key_values, **user_defined_values  # type: ignore
     )
     key = response["token"]  # type: ignore
     user_id = response["user_id"]  # type: ignore
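The UI key is now built by unpacking two dicts into `generate_key_helper_fn`. Because Python raises a `TypeError` when the same keyword is passed twice, the fixed values (`default_ui_key_values`) and the per-user values (`user_defined_values`) must keep disjoint keys. A minimal sketch of the pattern, using a hypothetical stand-in for the helper:

```python
def make_key(**params):
    # stand-in for generate_key_helper_fn; just echoes the merged kwargs
    return params

defaults = {"duration": "1hr", "key_max_budget": 0.01, "team_id": "litellm-dashboard"}
per_user = {"models": ["gpt-3.5-turbo"], "user_id": "u-123", "user_email": "u@example.com"}

print(make_key(**defaults, **per_user))
# An overlapping key would raise:
# TypeError: make_key() got multiple values for keyword argument '<key>'
```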

@@ -554,6 +554,11 @@ class PrismaClient:
                         f"PrismaClient: find_unique for token: {hashed_token}"
                     )
                     if query_type == "find_unique":
+                        if token is None:
+                            raise HTTPException(
+                                status_code=400,
+                                detail={"error": f"No token passed in. Token={token}"},
+                            )
                         response = await self.db.litellm_verificationtoken.find_unique(
                             where={"token": hashed_token}
                         )

@@ -830,8 +830,8 @@ class Router:
             verbose_router_logger.info(
                 f"litellm.atext_completion(model={model})\033[31m Exception {str(e)}\033[0m"
             )
-            if model_name is not None:
-                self.fail_calls[model_name] += 1
+            if model is not None:
+                self.fail_calls[model] += 1
             raise e

     def embedding(

@@ -206,6 +206,7 @@ def test_async_custom_handler_stream():
 # test_async_custom_handler_stream()

+@pytest.mark.skip(reason="Flaky test")
 def test_azure_completion_stream():
     # [PROD Test] - Do not DELETE
     # test if completion() + sync custom logger get the same complete stream response

@@ -1655,6 +1655,202 @@ def test_openai_streaming_and_function_calling():
         raise e
# test_azure_streaming_and_function_calling()
def test_success_callback_streaming():
def success_callback(kwargs, completion_response, start_time, end_time):
print(
{
"success": True,
"input": kwargs,
"output": completion_response,
"start_time": start_time,
"end_time": end_time,
}
)
litellm.success_callback = [success_callback]
messages = [{"role": "user", "content": "hello"}]
print("TESTING LITELLM COMPLETION CALL")
response = litellm.completion(
model="j2-light",
messages=messages,
stream=True,
max_tokens=5,
)
print(response)
for chunk in response:
print(chunk["choices"][0])
# test_success_callback_streaming()
#### STREAMING + FUNCTION CALLING ###
from pydantic import BaseModel
from typing import List, Optional
class Function(BaseModel):
name: str
arguments: str
class ToolCalls(BaseModel):
index: int
id: str
type: str
function: Function
class Delta(BaseModel):
role: str
content: Optional[str]
tool_calls: List[ToolCalls]
class Choices(BaseModel):
index: int
delta: Delta
logprobs: Optional[str]
finish_reason: Optional[str]
class Chunk(BaseModel):
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[Choices]
def validate_first_streaming_function_calling_chunk(chunk: ModelResponse):
chunk_instance = Chunk(**chunk.model_dump())
### Chunk 1
# {
# "id": "chatcmpl-8vdVjtzxc0JqGjq93NxC79dMp6Qcs",
# "object": "chat.completion.chunk",
# "created": 1708747267,
# "model": "gpt-3.5-turbo-0125",
# "system_fingerprint": "fp_86156a94a0",
# "choices": [
# {
# "index": 0,
# "delta": {
# "role": "assistant",
# "content": null,
# "tool_calls": [
# {
# "index": 0,
# "id": "call_oN10vaaC9iA8GLFRIFwjCsN7",
# "type": "function",
# "function": {
# "name": "get_current_weather",
# "arguments": ""
# }
# }
# ]
# },
# "logprobs": null,
# "finish_reason": null
# }
# ]
# }
class Function2(BaseModel):
arguments: str
class ToolCalls2(BaseModel):
index: int
function: Optional[Function2]
class Delta2(BaseModel):
tool_calls: List[ToolCalls2]
class Choices2(BaseModel):
index: int
delta: Delta2
logprobs: Optional[str]
finish_reason: Optional[str]
class Chunk2(BaseModel):
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[Choices2]
## Chunk 2
# {
# "id": "chatcmpl-8vdVjtzxc0JqGjq93NxC79dMp6Qcs",
# "object": "chat.completion.chunk",
# "created": 1708747267,
# "model": "gpt-3.5-turbo-0125",
# "system_fingerprint": "fp_86156a94a0",
# "choices": [
# {
# "index": 0,
# "delta": {
# "tool_calls": [
# {
# "index": 0,
# "function": {
# "arguments": "{\""
# }
# }
# ]
# },
# "logprobs": null,
# "finish_reason": null
# }
# ]
# }
def validate_second_streaming_function_calling_chunk(chunk: ModelResponse):
chunk_instance = Chunk2(**chunk.model_dump())
class Delta3(BaseModel):
content: Optional[str] = None
role: Optional[str] = None
function_call: Optional[dict] = None
tool_calls: Optional[List] = None
class Choices3(BaseModel):
index: int
delta: Delta3
logprobs: Optional[str]
finish_reason: str
class Chunk3(BaseModel):
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[Choices3]
def validate_final_streaming_function_calling_chunk(chunk: ModelResponse):
chunk_instance = Chunk3(**chunk.model_dump())
 def test_azure_streaming_and_function_calling():
     tools = [
         {
@@ -1690,6 +1886,7 @@ def test_azure_streaming_and_function_calling():
         )
         # Add any assertions here to check the response
         for idx, chunk in enumerate(response):
+            print(f"chunk: {chunk}")
             if idx == 0:
                 assert (
                     chunk.choices[0].delta.tool_calls[0].function.arguments is not None
@@ -1697,40 +1894,69 @@ def test_azure_streaming_and_function_calling():
                 assert isinstance(
                     chunk.choices[0].delta.tool_calls[0].function.arguments, str
                 )
+                validate_first_streaming_function_calling_chunk(chunk=chunk)
+            elif idx == 1:
+                validate_second_streaming_function_calling_chunk(chunk=chunk)
+            elif chunk.choices[0].finish_reason is not None:  # last chunk
+                validate_final_streaming_function_calling_chunk(chunk=chunk)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
         raise e


-# test_azure_streaming_and_function_calling()
-
-
-def test_success_callback_streaming():
-    def success_callback(kwargs, completion_response, start_time, end_time):
-        print(
-            {
-                "success": True,
-                "input": kwargs,
-                "output": completion_response,
-                "start_time": start_time,
-                "end_time": end_time,
-            }
-        )
-
-    litellm.success_callback = [success_callback]
-
-    messages = [{"role": "user", "content": "hello"}]
-    print("TESTING LITELLM COMPLETION CALL")
-    response = litellm.completion(
-        model="j2-light",
-        messages=messages,
-        stream=True,
-        max_tokens=5,
-    )
-    print(response)
-    for chunk in response:
-        print(chunk["choices"][0])
-
-
-# test_success_callback_streaming()
+@pytest.mark.asyncio
+async def test_azure_astreaming_and_function_calling():
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
+    try:
+        response = await litellm.acompletion(
+            model="azure/gpt-4-nov-release",
+            tools=tools,
+            tool_choice="auto",
+            messages=messages,
+            stream=True,
+            api_base=os.getenv("AZURE_FRANCE_API_BASE"),
+            api_key=os.getenv("AZURE_FRANCE_API_KEY"),
+            api_version="2024-02-15-preview",
+        )
+        # Add any assertions here to check the response
+        idx = 0
+        async for chunk in response:
+            print(f"chunk: {chunk}")
+            if idx == 0:
+                assert (
+                    chunk.choices[0].delta.tool_calls[0].function.arguments is not None
+                )
+                assert isinstance(
+                    chunk.choices[0].delta.tool_calls[0].function.arguments, str
+                )
+                validate_first_streaming_function_calling_chunk(chunk=chunk)
+            elif idx == 1:
+                validate_second_streaming_function_calling_chunk(chunk=chunk)
+            elif chunk.choices[0].finish_reason is not None:  # last chunk
+                validate_final_streaming_function_calling_chunk(chunk=chunk)
+            idx += 1
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+        raise e

@@ -376,11 +376,9 @@ class StreamingChoices(OpenAIObject):
             self.delta = delta
         else:
             self.delta = Delta()
-        if logprobs is not None:
-            self.logprobs = logprobs
-
         if enhancements is not None:
             self.enhancements = enhancements
+        self.logprobs = logprobs

     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
@@ -820,6 +818,8 @@ class Logging:
         ## DYNAMIC LANGFUSE KEYS ##
         self.langfuse_public_key = langfuse_public_key
         self.langfuse_secret = langfuse_secret
+        ## TIME TO FIRST TOKEN LOGGING ##
+        self.completion_start_time: Optional[datetime.datetime] = None

     def update_environment_variables(
         self, model, user, optional_params, litellm_params, **additional_params
@@ -840,6 +840,7 @@ class Logging:
             "user": user,
             "call_type": str(self.call_type),
             "litellm_call_id": self.litellm_call_id,
+            "completion_start_time": self.completion_start_time,
             **self.optional_params,
             **additional_params,
         }
@@ -1069,6 +1070,11 @@ class Logging:
             start_time = self.start_time
         if end_time is None:
             end_time = datetime.datetime.now()
+        if self.completion_start_time is None:
+            self.completion_start_time = end_time
+            self.model_call_details["completion_start_time"] = (
+                self.completion_start_time
+            )
         self.model_call_details["log_event_type"] = "successful_api_call"
         self.model_call_details["end_time"] = end_time
         self.model_call_details["cache_hit"] = cache_hit
@@ -1358,7 +1364,7 @@ class Logging:
                         f"is complete_streaming_response in kwargs: {kwargs.get('complete_streaming_response', None)}"
                     )
                     if complete_streaming_response is None:
-                        break
+                        continue
                     else:
                         print_verbose("reaches langfuse for streaming logging!")
                         result = kwargs["complete_streaming_response"]
@@ -8629,6 +8635,10 @@ class CustomStreamWrapper:
                     model_response.choices[0].finish_reason = response_obj[
                         "finish_reason"
                     ]
+                if response_obj.get("original_chunk", None) is not None:
+                    model_response.system_fingerprint = getattr(
+                        response_obj["original_chunk"], "system_fingerprint", None
+                    )
                 if response_obj["logprobs"] is not None:
                     model_response.choices[0].logprobs = response_obj["logprobs"]
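Because `completion_start_time` is stamped into `model_call_details` above, custom success callbacks can derive a time-to-first-token metric for streamed calls. A minimal sketch, assuming the callback receives the standard litellm callback arguments:

```python
import litellm

def log_ttft(kwargs, completion_response, start_time, end_time):
    # completion_start_time is set when the first chunk is handled (see hunk above);
    # for non-streaming calls it effectively falls back to end_time.
    first_token_at = kwargs.get("completion_start_time") or end_time
    ttft = (first_token_at - start_time).total_seconds()
    print(f"time to first token: {ttft:.3f}s")

litellm.success_callback = [log_ttft]
```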

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.26.14.dev1"
+version = "1.27.1"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -74,7 +74,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"

 [tool.commitizen]
-version = "1.26.14.dev1"
+version = "1.27.1"
 version_files = [
     "pyproject.toml:^version"
 ]

@@ -97,10 +97,10 @@ async def test_chat_completion_old_key():
     """
     async with aiohttp.ClientSession() as session:
         try:
-            key = "sk-yNXvlRO4SxIGG0XnRMYxTw"
+            key = "sk-ecMXHujzUtKCvHcwacdaTw"
             await chat_completion(session=session, key=key)
         except Exception as e:
-            key = "sk-2KV0sAElLQqMpLZXdNf3yw"  # try diff db key (in case db url is for the other db)
+            key = "sk-ecMXHujzUtKCvHcwacdaTw"  # try diff db key (in case db url is for the other db)
            await chat_completion(session=session, key=key)

@@ -1,17 +1,26 @@
 import React, { useState, useEffect } from "react";
 import ReactMarkdown from "react-markdown";
 import {
   Card,
   Title,
   Table,
   TableHead,
   TableRow,
   TableCell,
   TableBody,
   Grid,
   Tab,
   TabGroup,
   TabList,
   TabPanel,
   Metric,
   Select,
   SelectItem,
   TabPanels,
 } from "@tremor/react";
-import { modelInfoCall } from "./networking";
+import { modelAvailableCall } from "./networking";
 import openai from "openai";
 import { Prism as SyntaxHighlighter } from "react-syntax-highlighter";

 interface ChatUIProps {
   accessToken: string | null;
@@ -20,11 +29,18 @@ interface ChatUIProps {
   userID: string | null;
 }

 async function generateModelResponse(
   inputMessage: string,
   updateUI: (chunk: string) => void,
   selectedModel: string,
   accessToken: string
 ) {
   // base url should be the current base_url
   const isLocal = process.env.NODE_ENV === "development";
   console.log("isLocal:", isLocal);
   const proxyBaseUrl = isLocal
     ? "http://localhost:4000"
     : window.location.origin;
   const client = new openai.OpenAI({
     apiKey: accessToken, // Replace with your OpenAI API key
     baseURL: proxyBaseUrl, // Replace with your OpenAI API base URL
@@ -36,7 +52,7 @@ async function generateModelResponse(inputMessage: string, updateUI: (chunk: str
     stream: true,
     messages: [
       {
         role: "user",
         content: inputMessage,
       },
     ],
@@ -50,10 +66,17 @@ async function generateModelResponse(inputMessage: string, updateUI: (chunk: str
   }
 }

 const ChatUI: React.FC<ChatUIProps> = ({
   accessToken,
   token,
   userRole,
   userID,
 }) => {
   const [inputMessage, setInputMessage] = useState("");
   const [chatHistory, setChatHistory] = useState<any[]>([]);
   const [selectedModel, setSelectedModel] = useState<string | undefined>(
     undefined
   );
   const [modelInfo, setModelInfo] = useState<any | null>(null); // Declare modelInfo at the component level

   useEffect(() => {
@@ -62,12 +85,16 @@ const ChatUI: React.FC<ChatUIProps> = ({ accessToken, token, userRole, userID })
     }
     // Fetch model info and set the default selected model
     const fetchModelInfo = async () => {
-      const fetchedModelInfo = await modelInfoCall(accessToken, userID, userRole);
-      console.log("model_info:", fetchedModelInfo);
-      if (fetchedModelInfo?.data.length > 0) {
-        setModelInfo(fetchedModelInfo);
-        setSelectedModel(fetchedModelInfo.data[0].model_name);
-      }
+      const fetchedAvailableModels = await modelAvailableCall(
+        accessToken,
+        userID,
+        userRole
+      );
+      console.log("model_info:", fetchedAvailableModels);
+      if (fetchedAvailableModels?.data.length > 0) {
+        setModelInfo(fetchedAvailableModels.data);
+        setSelectedModel(fetchedAvailableModels.data[0].id);
+      }
     };
@@ -103,7 +130,12 @@ const ChatUI: React.FC<ChatUIProps> = ({ accessToken, token, userRole, userID })
     try {
       if (selectedModel) {
         await generateModelResponse(
           inputMessage,
           (chunk) => updateUI("assistant", chunk),
           selectedModel,
           accessToken
         );
       }
     } catch (error) {
       console.error("Error fetching model response", error);
@@ -132,14 +164,21 @@ const ChatUI: React.FC<ChatUIProps> = ({ accessToken, token, userRole, userID })
             onChange={(e) => setSelectedModel(e.target.value)}
           >
             {/* Populate dropdown options from available models */}
-            {modelInfo?.data.map((element: { model_name: string }) => (
-              <option key={element.model_name} value={element.model_name}>
-                {element.model_name}
-              </option>
-            ))}
+            {modelInfo?.map((element: { id: string }) => (
+              <option key={element.id} value={element.id}>
+                {element.id}
+              </option>
+            ))}
           </select>
         </div>
         <Table
           className="mt-5"
           style={{
             display: "block",
             maxHeight: "60vh",
             overflowY: "auto",
           }}
         >
           <TableHead>
             <TableRow>
               <TableCell>
@@ -155,7 +194,10 @@ const ChatUI: React.FC<ChatUIProps> = ({ accessToken, token, userRole, userID })
             ))}
           </TableBody>
         </Table>
         <div
           className="mt-3"
           style={{ position: "absolute", bottom: 5, width: "95%" }}
         >
           <div className="flex">
             <input
               type="text"
@@ -164,7 +206,10 @@ const ChatUI: React.FC<ChatUIProps> = ({ accessToken, token, userRole, userID })
               className="flex-1 p-2 border rounded-md mr-2"
               placeholder="Type your message..."
             />
             <button
               onClick={handleSendMessage}
               className="p-2 bg-blue-500 text-white rounded-md"
             >
               Send
             </button>
           </div>
@@ -179,7 +224,6 @@ const ChatUI: React.FC<ChatUIProps> = ({ accessToken, token, userRole, userID })
           </TabList>
           <TabPanels>
             <TabPanel>
               <SyntaxHighlighter language="python">
                 {`
 import openai
@@ -211,7 +255,6 @@ print(response)
               </SyntaxHighlighter>
             </TabPanel>
             <TabPanel>
               <SyntaxHighlighter language="python">
                 {`
 import os, dotenv
@@ -248,7 +291,6 @@ print(response)
               </SyntaxHighlighter>
             </TabPanel>
             <TabPanel>
               <SyntaxHighlighter language="python">
                 {`
 from langchain.chat_models import ChatOpenAI
@@ -290,7 +332,6 @@ print(response)
             </TabPanel>
           </TabPanels>
         </TabGroup>
       </TabPanel>
     </TabPanels>
   </TabGroup>
@@ -300,6 +341,4 @@ print(response)
   );
 };

 export default ChatUI;

View file

@@ -1,10 +1,17 @@
 "use client";
 import React, { useState, useEffect, useRef } from "react";
 import { Button, TextInput, Grid, Col } from "@tremor/react";
 import { Card, Metric, Text } from "@tremor/react";
 import {
   Button as Button2,
   Modal,
   Form,
   Input,
   InputNumber,
   Select,
   message,
 } from "antd";
 import { keyCreateCall } from "./networking";

 const { Option } = Select;
@@ -50,13 +57,12 @@ const CreateKey: React.FC<CreateKeyProps> = ({
       setApiKey(response["key"]);
       message.success("API Key Created");
       form.resetFields();
       localStorage.removeItem("userData" + userID);
     } catch (error) {
       console.error("Error creating the key:", error);
     }
   };

   return (
     <div>
       <Button className="mx-auto" onClick={() => setIsModalVisible(true)}>
@@ -70,30 +76,26 @@ const CreateKey: React.FC<CreateKeyProps> = ({
         onOk={handleOk}
         onCancel={handleCancel}
       >
         <Form
           form={form}
           onFinish={handleCreate}
           labelCol={{ span: 8 }}
           wrapperCol={{ span: 16 }}
           labelAlign="left"
         >
           {userRole === "App Owner" || userRole === "Admin" ? (
             <>
               <Form.Item label="Key Name" name="key_alias">
                 <Input />
               </Form.Item>
               <Form.Item label="Team ID" name="team_id">
                 <Input placeholder="ai_team" />
               </Form.Item>
               <Form.Item label="Models" name="models">
                 <Select
                   mode="multiple"
                   placeholder="Select models"
                   style={{ width: "100%" }}
                 >
                   {userModels.map((model) => (
                     <Option key={model} value={model}>
@@ -102,16 +104,10 @@ const CreateKey: React.FC<CreateKeyProps> = ({
                   ))}
                 </Select>
               </Form.Item>
               <Form.Item label="Max Budget (USD)" name="max_budget">
                 <InputNumber step={0.01} precision={2} width={200} />
               </Form.Item>
               <Form.Item label="Tokens per minute Limit (TPM)" name="tpm_limit">
                 <InputNumber step={1} width={400} />
               </Form.Item>
               <Form.Item
@@ -120,47 +116,29 @@ const CreateKey: React.FC<CreateKeyProps> = ({
               >
                 <InputNumber step={1} width={400} />
               </Form.Item>
               <Form.Item label="Duration (eg: 30s, 30h, 30d)" name="duration">
                 <Input />
               </Form.Item>
               <Form.Item label="Metadata" name="metadata">
                 <Input.TextArea rows={4} placeholder="Enter metadata as JSON" />
               </Form.Item>
             </>
           ) : (
             <>
               <Form.Item label="Key Name" name="key_alias">
                 <Input />
               </Form.Item>
               <Form.Item label="Team ID (Contact Group)" name="team_id">
                 <Input placeholder="ai_team" />
               </Form.Item>
               <Form.Item label="Description" name="description">
                 <Input.TextArea placeholder="Enter description" rows={4} />
               </Form.Item>
             </>
           )}
           <div style={{ textAlign: "right", marginTop: "10px" }}>
             <Button2 htmlType="submit">Create Key</Button2>
           </div>
         </Form>
       </Modal>
@@ -177,8 +155,8 @@ const CreateKey: React.FC<CreateKeyProps> = ({
             <p>
               Please save this secret key somewhere safe and accessible. For
               security reasons, <b>you will not be able to view it again</b>{" "}
               through your LiteLLM account. If you lose this secret key, you
               will need to generate a new one.
             </p>
           </Col>
           <Col numColSpan={1}>

@@ -1,7 +1,7 @@
 import React, { useState, useEffect } from "react";
 import { Button, Modal, Form, Input, message, Select, InputNumber } from "antd";
 import { Button as Button2 } from "@tremor/react";
-import { userCreateCall, modelInfoCall } from "./networking";
+import { userCreateCall, modelAvailableCall } from "./networking";

 const { Option } = Select;

 interface CreateuserProps {
@@ -20,12 +20,16 @@ const Createuser: React.FC<CreateuserProps> = ({ userID, accessToken }) => {
   const fetchData = async () => {
     try {
       const userRole = "any"; // You may need to get the user role dynamically
-      const modelDataResponse = await modelInfoCall(accessToken, userID, userRole);
+      const modelDataResponse = await modelAvailableCall(
+        accessToken,
+        userID,
+        userRole
+      );
       // Assuming modelDataResponse.data contains an array of model objects with a 'model_name' property
       const availableModels = [];
       for (let i = 0; i < modelDataResponse.data.length; i++) {
         const model = modelDataResponse.data[i];
-        availableModels.push(model.model_name);
+        availableModels.push(model.id);
       }
       console.log("Model data response:", modelDataResponse.data);
       console.log("Available models:", availableModels);
@@ -79,27 +83,24 @@ const Createuser: React.FC<CreateuserProps> = ({ userID, accessToken }) => {
         onOk={handleOk}
         onCancel={handleCancel}
       >
         <Form
           form={form}
           onFinish={handleCreate}
           labelCol={{ span: 8 }}
           wrapperCol={{ span: 16 }}
           labelAlign="left"
         >
           <Form.Item label="User ID" name="user_id">
             <Input placeholder="Enter User ID" />
           </Form.Item>
           <Form.Item label="Team ID" name="team_id">
             <Input placeholder="ai_team" />
           </Form.Item>
           <Form.Item label="Models" name="models">
             <Select
               mode="multiple"
               placeholder="Select models"
               style={{ width: "100%" }}
             >
               {userModels.map((model) => (
                 <Option key={model} value={model}>
@@ -109,46 +110,23 @@ const Createuser: React.FC<CreateuserProps> = ({ userID, accessToken }) => {
             </Select>
           </Form.Item>
           <Form.Item label="Max Budget (USD)" name="max_budget">
             <InputNumber step={0.01} precision={2} width={200} />
           </Form.Item>
           <Form.Item label="Tokens per minute Limit (TPM)" name="tpm_limit">
             <InputNumber step={1} width={400} />
           </Form.Item>
           <Form.Item label="Requests per minute Limit (RPM)" name="rpm_limit">
             <InputNumber step={1} width={400} />
           </Form.Item>
           <Form.Item label="Duration (eg: 30s, 30h, 30d)" name="duration">
             <Input />
           </Form.Item>
           <Form.Item label="Metadata" name="metadata">
             <Input.TextArea rows={4} placeholder="Enter metadata as JSON" />
           </Form.Item>
           <div style={{ textAlign: "right", marginTop: "10px" }}>
             <Button htmlType="submit">Create User</Button>
           </div>
         </Form>
       </Modal>
@@ -162,11 +140,15 @@ const Createuser: React.FC<CreateuserProps> = ({ userID, accessToken }) => {
         >
           <p>
             Please save this secret user somewhere safe and accessible. For
             security reasons, <b>you will not be able to view it again</b>{" "}
             through your LiteLLM account. If you lose this secret user, you will
             need to generate a new one.
           </p>
           <p>
             {apiuser != null
               ? `API user: ${apiuser}`
               : "User being created, this might take 30s"}
           </p>
         </Modal>
       )}
     </div>

@@ -69,7 +69,6 @@ export const keyCreateCall = async (
   }
 };

 export const userCreateCall = async (
   accessToken: string,
   userID: string,
@@ -133,7 +132,6 @@ export const userCreateCall = async (
   }
 };

 export const keyDeleteCall = async (accessToken: String, user_key: String) => {
   try {
     const url = proxyBaseUrl ? `${proxyBaseUrl}/key/delete` : `/key/delete`;
@@ -207,13 +205,14 @@ export const userInfoCall = async (
   }
 };

 export const modelInfoCall = async (
   accessToken: String,
   userID: String,
   userRole: String
 ) => {
+  /**
+   * Get all models on proxy
+   */
   try {
     let url = proxyBaseUrl ? `${proxyBaseUrl}/v2/model/info` : `/v2/model/info`;
@@ -242,6 +241,42 @@ export const modelInfoCall = async (
   }
 };

+export const modelAvailableCall = async (
+  accessToken: String,
+  userID: String,
+  userRole: String
+) => {
+  /**
+   * Get all the models user has access to
+   */
+  try {
+    let url = proxyBaseUrl ? `${proxyBaseUrl}/models` : `/models`;
+
+    message.info("Requesting model data");
+    const response = await fetch(url, {
+      method: "GET",
+      headers: {
+        Authorization: `Bearer ${accessToken}`,
+        "Content-Type": "application/json",
+      },
+    });
+
+    if (!response.ok) {
+      const errorData = await response.text();
+      message.error(errorData);
+      throw new Error("Network response was not ok");
+    }
+
+    const data = await response.json();
+    message.info("Received model data");
+    return data;
+    // Handle success - you might want to update some state or UI based on the created key
+  } catch (error) {
+    console.error("Failed to create key:", error);
+    throw error;
+  }
+};
+
 export const keySpendLogsCall = async (accessToken: String, token: String) => {
   try {
     const url = proxyBaseUrl ? `${proxyBaseUrl}/spend/logs` : `/spend/logs`;
@@ -363,12 +398,16 @@ export const spendUsersCall = async (accessToken: String, userID: String) => {
   }
 };

-export const userRequestModelCall = async (accessToken: String, model: String, UserID: String, justification: String) => {
+export const userRequestModelCall = async (
+  accessToken: String,
+  model: String,
+  UserID: String,
+  justification: String
+) => {
   try {
-    const url = proxyBaseUrl ? `${proxyBaseUrl}/user/request_model` : `/user/request_model`;
+    const url = proxyBaseUrl
+      ? `${proxyBaseUrl}/user/request_model`
+      : `/user/request_model`;
     const response = await fetch(url, {
       method: "POST",
       headers: {
@@ -398,10 +437,11 @@ export const userRequestModelCall = async (accessToken: String, model: String, U
   }
 };

 export const userGetRequesedtModelsCall = async (accessToken: String) => {
   try {
-    const url = proxyBaseUrl ? `${proxyBaseUrl}/user/get_requests` : `/user/get_requests`;
+    const url = proxyBaseUrl
+      ? `${proxyBaseUrl}/user/get_requests`
+      : `/user/get_requests`;
     console.log("in userGetRequesedtModelsCall:", url);
     const response = await fetch(url, {
       method: "GET",
@@ -1,6 +1,6 @@
 "use client";
 import React, { useState, useEffect } from "react";
-import { userInfoCall, modelInfoCall } from "./networking";
+import { userInfoCall, modelAvailableCall } from "./networking";
 import { Grid, Col, Card, Text } from "@tremor/react";
 import CreateKey from "./create_key_button";
 import ViewKeyTable from "./view_key_table";
@@ -48,10 +48,9 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
   const token = searchParams.get("token");
   const [accessToken, setAccessToken] = useState<string | null>(null);
   const [userModels, setUserModels] = useState<string[]>([]);

   // check if window is not undefined
   if (typeof window !== "undefined") {
     window.addEventListener("beforeunload", function () {
       // Clear session storage
       sessionStorage.clear();
     });
@@ -78,7 +77,6 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
   // Moved useEffect inside the component and used a condition to run fetch only if the params are available
   useEffect(() => {
     if (token) {
       const decoded = jwtDecode(token) as { [key: string]: any };
       if (decoded) {
@@ -109,32 +107,39 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
     const cachedUserModels = sessionStorage.getItem("userModels" + userID);
     if (cachedUserModels) {
       setUserModels(JSON.parse(cachedUserModels));
     } else {
       const fetchData = async () => {
         try {
           const response = await userInfoCall(accessToken, userID, userRole);
           setUserSpendData(response["user_info"]);
           setData(response["keys"]); // Assuming this is the correct path to your data
           sessionStorage.setItem(
             "userData" + userID,
             JSON.stringify(response["keys"])
           );
           sessionStorage.setItem(
             "userSpendData" + userID,
             JSON.stringify(response["user_info"])
           );
-          const model_info = await modelInfoCall(accessToken, userID, userRole);
-          console.log("model_info:", model_info);
-          // loop through model_info["data"] and create an array of element.model_name
-          let available_model_names = model_info["data"].filter((element: { model_name: string; user_access: boolean }) => element.user_access === true).map((element: { model_name: string; }) => element.model_name);
+          const model_available = await modelAvailableCall(
+            accessToken,
+            userID,
+            userRole
+          );
+          // loop through model_info["data"] and create an array of element.model_name
+          let available_model_names = model_available["data"].map(
+            (element: { id: string }) => element.id
+          );
           console.log("available_model_names:", available_model_names);
           setUserModels(available_model_names);
           console.log("userModels:", userModels);
           sessionStorage.setItem(
             "userModels" + userID,
             JSON.stringify(available_model_names)
           );
         } catch (error) {
           console.error("There was an error fetching the data", error);
           // Optionally, update your UI to reflect the error state here as well