forked from phoenix/litellm-mirror
Merge branch 'main' into litellm_team_settings
This commit is contained in:
commit
856aa9c30b
12 changed files with 681 additions and 113 deletions
|
@ -37,12 +37,12 @@ http://0.0.0.0:8000/ui # <proxy_base_url>/ui
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Get Admin UI Link on Swagger
|
### 3. Get Admin UI Link on Swagger
|
||||||
Your Proxy Swagger is available on the root of the Proxy: e.g.: `http://localhost:4000/`
|
Your Proxy Swagger is available on the root of the Proxy: e.g.: `http://localhost:4000/`
|
||||||
|
|
||||||
<Image img={require('../../img/ui_link.png')} />
|
<Image img={require('../../img/ui_link.png')} />
|
||||||
|
|
||||||
## Change default username + password
|
### 4. Change default username + password
|
||||||
|
|
||||||
Set the following in your .env on the Proxy
|
Set the following in your .env on the Proxy
|
||||||
|
|
||||||
|
@ -111,6 +111,29 @@ MICROSOFT_TENANT="5a39737
|
||||||
|
|
||||||
</TabItem>
|
</TabItem>
|
||||||
|
|
||||||
|
|
||||||
|
<TabItem value="Generic" label="Generic SSO Provider">
|
||||||
|
|
||||||
|
A generic OAuth client that can be used to quickly create support for any OAuth provider with close to no code
|
||||||
|
|
||||||
|
**Required .env variables on your Proxy**
|
||||||
|
```shell
|
||||||
|
|
||||||
|
GENERIC_CLIENT_ID = "******"
|
||||||
|
GENERIC_CLIENT_SECRET = "G*******"
|
||||||
|
GENERIC_AUTHORIZATION_ENDPOINT = "http://localhost:9090/auth"
|
||||||
|
GENERIC_TOKEN_ENDPOINT = "http://localhost:9090/token"
|
||||||
|
GENERIC_USERINFO_ENDPOINT = "http://localhost:9090/me"
|
||||||
|
```
|
||||||
|
|
||||||
|
- Set Redirect URI, if your provider requires it
|
||||||
|
- Set a redirect url = `<your proxy base url>/sso/callback`
|
||||||
|
```shell
|
||||||
|
http://localhost:4000/sso/callback
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
### Step 3. Test flow
|
### Step 3. Test flow
|
||||||
|
|
|
@ -197,7 +197,7 @@ from openai import OpenAI
|
||||||
# set api_key to send to proxy server
|
# set api_key to send to proxy server
|
||||||
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:8000")
|
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:8000")
|
||||||
|
|
||||||
response = openai.embeddings.create(
|
response = client.embeddings.create(
|
||||||
input=["hello from litellm"],
|
input=["hello from litellm"],
|
||||||
model="text-embedding-ada-002"
|
model="text-embedding-ada-002"
|
||||||
)
|
)
|
||||||
|
@ -281,6 +281,84 @@ print(query_result[:5])
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## `/moderations`
|
||||||
|
|
||||||
|
|
||||||
|
### Request Format
|
||||||
|
Input, Output and Exceptions are mapped to the OpenAI format for all supported models
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
# set base_url to your proxy server
|
||||||
|
# set api_key to send to proxy server
|
||||||
|
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:8000")
|
||||||
|
|
||||||
|
response = client.moderations.create(
|
||||||
|
input="hello from litellm",
|
||||||
|
model="text-moderation-stable"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="Curl" label="Curl Request">
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://0.0.0.0:8000/moderations' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--data '{"input": "Sample text goes here", "model": "text-moderation-stable"}'
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
|
### Response Format
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "modr-8sFEN22QCziALOfWTa77TodNLgHwA",
|
||||||
|
"model": "text-moderation-007",
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"categories": {
|
||||||
|
"harassment": false,
|
||||||
|
"harassment/threatening": false,
|
||||||
|
"hate": false,
|
||||||
|
"hate/threatening": false,
|
||||||
|
"self-harm": false,
|
||||||
|
"self-harm/instructions": false,
|
||||||
|
"self-harm/intent": false,
|
||||||
|
"sexual": false,
|
||||||
|
"sexual/minors": false,
|
||||||
|
"violence": false,
|
||||||
|
"violence/graphic": false
|
||||||
|
},
|
||||||
|
"category_scores": {
|
||||||
|
"harassment": 0.000019947197870351374,
|
||||||
|
"harassment/threatening": 5.5971017900446896e-6,
|
||||||
|
"hate": 0.000028560316422954202,
|
||||||
|
"hate/threatening": 2.2631787999216613e-8,
|
||||||
|
"self-harm": 2.9121162015144364e-7,
|
||||||
|
"self-harm/instructions": 9.314219084899378e-8,
|
||||||
|
"self-harm/intent": 8.093739012338119e-8,
|
||||||
|
"sexual": 0.00004414955765241757,
|
||||||
|
"sexual/minors": 0.0000156943697220413,
|
||||||
|
"violence": 0.00022354527027346194,
|
||||||
|
"violence/graphic": 8.804164281173144e-6
|
||||||
|
},
|
||||||
|
"flagged": false
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## Advanced
|
## Advanced
|
||||||
|
|
||||||
|
|
|
@ -2962,16 +2962,39 @@ def text_completion(
|
||||||
|
|
||||||
|
|
||||||
##### Moderation #######################
|
##### Moderation #######################
|
||||||
def moderation(input: str, api_key: Optional[str] = None):
|
|
||||||
|
|
||||||
|
def moderation(
|
||||||
|
input: str, model: Optional[str] = None, api_key: Optional[str] = None, **kwargs
|
||||||
|
):
|
||||||
# only supports open ai for now
|
# only supports open ai for now
|
||||||
api_key = (
|
api_key = (
|
||||||
api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
|
api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
|
||||||
)
|
)
|
||||||
openai.api_key = api_key
|
|
||||||
openai.api_type = "open_ai" # type: ignore
|
openai_client = kwargs.get("client", None)
|
||||||
openai.api_version = None
|
if openai_client is None:
|
||||||
openai.base_url = "https://api.openai.com/v1/"
|
openai_client = openai.OpenAI(
|
||||||
response = openai.moderations.create(input=input)
|
api_key=api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = openai_client.moderations.create(input=input, model=model)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
##### Moderation #######################
|
||||||
|
@client
|
||||||
|
async def amoderation(input: str, model: str, api_key: Optional[str] = None, **kwargs):
|
||||||
|
# only supports open ai for now
|
||||||
|
api_key = (
|
||||||
|
api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
|
||||||
|
)
|
||||||
|
openai_client = kwargs.get("client", None)
|
||||||
|
if openai_client is None:
|
||||||
|
openai_client = openai.AsyncOpenAI(
|
||||||
|
api_key=api_key,
|
||||||
|
)
|
||||||
|
response = await openai_client.moderations.create(input=input, model=model)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -9,14 +9,19 @@ model_list:
|
||||||
mode: chat
|
mode: chat
|
||||||
max_tokens: 4096
|
max_tokens: 4096
|
||||||
base_model: azure/gpt-4-1106-preview
|
base_model: azure/gpt-4-1106-preview
|
||||||
|
access_groups: ["public"]
|
||||||
- model_name: openai-gpt-3.5
|
- model_name: openai-gpt-3.5
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: gpt-3.5-turbo
|
model: gpt-3.5-turbo
|
||||||
api_key: os.environ/OPENAI_API_KEY
|
api_key: os.environ/OPENAI_API_KEY
|
||||||
|
model_info:
|
||||||
|
access_groups: ["public"]
|
||||||
- model_name: anthropic-claude-v2.1
|
- model_name: anthropic-claude-v2.1
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: bedrock/anthropic.claude-v2:1
|
model: bedrock/anthropic.claude-v2:1
|
||||||
timeout: 300 # sets a 5 minute timeout
|
timeout: 300 # sets a 5 minute timeout
|
||||||
|
model_info:
|
||||||
|
access_groups: ["private"]
|
||||||
- model_name: anthropic-claude-v2
|
- model_name: anthropic-claude-v2
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: bedrock/anthropic.claude-v2
|
model: bedrock/anthropic.claude-v2
|
||||||
|
@ -32,19 +37,13 @@ model_list:
|
||||||
api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
|
api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
|
||||||
model_info:
|
model_info:
|
||||||
base_model: azure/gpt-4
|
base_model: azure/gpt-4
|
||||||
|
- model_name: text-moderation-stable
|
||||||
|
litellm_params:
|
||||||
|
model: text-moderation-stable
|
||||||
|
api_key: os.environ/OPENAI_API_KEY
|
||||||
litellm_settings:
|
litellm_settings:
|
||||||
fallbacks: [{"openai-gpt-3.5": ["azure-gpt-3.5"]}]
|
fallbacks: [{"openai-gpt-3.5": ["azure-gpt-3.5"]}]
|
||||||
success_callback: ['langfuse']
|
success_callback: ['langfuse']
|
||||||
max_budget: 50 # global budget for proxy
|
|
||||||
max_user_budget: 0.0001
|
|
||||||
budget_duration: 30d # global budget duration, will reset after 30d
|
|
||||||
default_key_generate_params:
|
|
||||||
max_budget: 1.5000
|
|
||||||
models: ["azure-gpt-3.5"]
|
|
||||||
duration: None
|
|
||||||
upperbound_key_generate_params:
|
|
||||||
max_budget: 100
|
|
||||||
duration: "30d"
|
|
||||||
# setting callback class
|
# setting callback class
|
||||||
# callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
|
# callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
|
||||||
|
|
||||||
|
|
|
@ -403,34 +403,43 @@ async def user_api_key_auth(
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"LLM Model List pre access group check: {llm_model_list}"
|
f"LLM Model List pre access group check: {llm_model_list}"
|
||||||
)
|
)
|
||||||
access_groups = []
|
from collections import defaultdict
|
||||||
|
|
||||||
|
access_groups = defaultdict(list)
|
||||||
if llm_model_list is not None:
|
if llm_model_list is not None:
|
||||||
for m in llm_model_list:
|
for m in llm_model_list:
|
||||||
for group in m.get("model_info", {}).get("access_groups", []):
|
for group in m.get("model_info", {}).get("access_groups", []):
|
||||||
access_groups.append((m["model_name"], group))
|
model_name = m["model_name"]
|
||||||
|
access_groups[group].append(model_name)
|
||||||
|
|
||||||
allowed_models = valid_token.models
|
models_in_current_access_groups = []
|
||||||
access_group_idx = set()
|
|
||||||
if (
|
if (
|
||||||
len(access_groups) > 0
|
len(access_groups) > 0
|
||||||
): # check if token contains any model access groups
|
): # check if token contains any model access groups
|
||||||
for idx, m in enumerate(valid_token.models):
|
for idx, m in enumerate(
|
||||||
for model_name, group in access_groups:
|
valid_token.models
|
||||||
if m == group:
|
): # loop token models, if any of them are an access group add the access group
|
||||||
access_group_idx.add(idx)
|
if m in access_groups:
|
||||||
allowed_models.append(model_name)
|
# if it is an access group we need to remove it from valid_token.models
|
||||||
|
models_in_group = access_groups[m]
|
||||||
|
models_in_current_access_groups.extend(models_in_group)
|
||||||
|
|
||||||
|
# Filter out models that are access_groups
|
||||||
|
filtered_models = [
|
||||||
|
m for m in valid_token.models if m not in access_groups
|
||||||
|
]
|
||||||
|
|
||||||
|
filtered_models += models_in_current_access_groups
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"model: {model}; allowed_models: {allowed_models}"
|
f"model: {model}; allowed_models: {filtered_models}"
|
||||||
)
|
)
|
||||||
if model is not None and model not in allowed_models:
|
if model is not None and model not in filtered_models:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
|
f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
|
||||||
)
|
)
|
||||||
for val in access_group_idx:
|
valid_token.models = filtered_models
|
||||||
allowed_models.pop(val)
|
|
||||||
valid_token.models = allowed_models
|
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"filtered allowed_models: {allowed_models}; valid_token.models: {valid_token.models}"
|
f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check 2. If user_id for this token is in budget
|
# Check 2. If user_id for this token is in budget
|
||||||
|
@ -2087,14 +2096,6 @@ def model_list(
|
||||||
if user_model is not None:
|
if user_model is not None:
|
||||||
all_models += [user_model]
|
all_models += [user_model]
|
||||||
verbose_proxy_logger.debug(f"all_models: {all_models}")
|
verbose_proxy_logger.debug(f"all_models: {all_models}")
|
||||||
### CHECK OLLAMA MODELS ###
|
|
||||||
try:
|
|
||||||
response = requests.get("http://0.0.0.0:11434/api/tags")
|
|
||||||
models = response.json()["models"]
|
|
||||||
ollama_models = ["ollama/" + m["name"].replace(":latest", "") for m in models]
|
|
||||||
all_models.extend(ollama_models)
|
|
||||||
except Exception as e:
|
|
||||||
pass
|
|
||||||
return dict(
|
return dict(
|
||||||
data=[
|
data=[
|
||||||
{
|
{
|
||||||
|
@ -2798,7 +2799,161 @@ async def image_generation(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
#### KEY MANAGEMENT #####
|
|
||||||
|
@router.post(
|
||||||
|
"/v1/moderations",
|
||||||
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
response_class=ORJSONResponse,
|
||||||
|
tags=["moderations"],
|
||||||
|
)
|
||||||
|
@router.post(
|
||||||
|
"/moderations",
|
||||||
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
response_class=ORJSONResponse,
|
||||||
|
tags=["moderations"],
|
||||||
|
)
|
||||||
|
async def moderations(
|
||||||
|
request: Request,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
The moderations endpoint is a tool you can use to check whether content complies with an LLM Providers policies.
|
||||||
|
|
||||||
|
Quick Start
|
||||||
|
```
|
||||||
|
curl --location 'http://0.0.0.0:4000/moderations' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--data '{"input": "Sample text goes here", "model": "text-moderation-stable"}'
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
global proxy_logging_obj
|
||||||
|
try:
|
||||||
|
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
||||||
|
body = await request.body()
|
||||||
|
data = orjson.loads(body)
|
||||||
|
|
||||||
|
# Include original request and headers in the data
|
||||||
|
data["proxy_server_request"] = {
|
||||||
|
"url": str(request.url),
|
||||||
|
"method": request.method,
|
||||||
|
"headers": dict(request.headers),
|
||||||
|
"body": copy.copy(data), # use copy instead of deepcopy
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
|
||||||
|
data["user"] = user_api_key_dict.user_id
|
||||||
|
|
||||||
|
data["model"] = (
|
||||||
|
general_settings.get("moderation_model", None) # server default
|
||||||
|
or user_model # model name passed via cli args
|
||||||
|
or data["model"] # default passed in http request
|
||||||
|
)
|
||||||
|
if user_model:
|
||||||
|
data["model"] = user_model
|
||||||
|
|
||||||
|
if "metadata" not in data:
|
||||||
|
data["metadata"] = {}
|
||||||
|
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
|
||||||
|
data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
|
||||||
|
_headers = dict(request.headers)
|
||||||
|
_headers.pop(
|
||||||
|
"authorization", None
|
||||||
|
) # do not store the original `sk-..` api key in the db
|
||||||
|
data["metadata"]["headers"] = _headers
|
||||||
|
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
||||||
|
data["metadata"]["endpoint"] = str(request.url)
|
||||||
|
|
||||||
|
### TEAM-SPECIFIC PARAMS ###
|
||||||
|
if user_api_key_dict.team_id is not None:
|
||||||
|
team_config = await proxy_config.load_team_config(
|
||||||
|
team_id=user_api_key_dict.team_id
|
||||||
|
)
|
||||||
|
if len(team_config) == 0:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
team_id = team_config.pop("team_id", None)
|
||||||
|
data["metadata"]["team_id"] = team_id
|
||||||
|
data = {
|
||||||
|
**team_config,
|
||||||
|
**data,
|
||||||
|
} # add the team-specific configs to the completion call
|
||||||
|
|
||||||
|
router_model_names = (
|
||||||
|
[m["model_name"] for m in llm_model_list]
|
||||||
|
if llm_model_list is not None
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
|
### CALL HOOKS ### - modify incoming data / reject request before calling the model
|
||||||
|
data = await proxy_logging_obj.pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict, data=data, call_type="moderation"
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
## ROUTE TO CORRECT ENDPOINT ##
|
||||||
|
# skip router if user passed their key
|
||||||
|
if "api_key" in data:
|
||||||
|
response = await litellm.amoderation(**data)
|
||||||
|
elif (
|
||||||
|
llm_router is not None and data["model"] in router_model_names
|
||||||
|
): # model in router model list
|
||||||
|
response = await llm_router.amoderation(**data)
|
||||||
|
elif (
|
||||||
|
llm_router is not None and data["model"] in llm_router.deployment_names
|
||||||
|
): # model in router deployments, calling a specific deployment on the router
|
||||||
|
response = await llm_router.amoderation(**data, specific_deployment=True)
|
||||||
|
elif (
|
||||||
|
llm_router is not None
|
||||||
|
and llm_router.model_group_alias is not None
|
||||||
|
and data["model"] in llm_router.model_group_alias
|
||||||
|
): # model set in model_group_alias
|
||||||
|
response = await llm_router.amoderation(
|
||||||
|
**data
|
||||||
|
) # ensure this goes the llm_router, router will do the correct alias mapping
|
||||||
|
elif user_model is not None: # `litellm --model <your-model-name>`
|
||||||
|
response = await litellm.amoderation(**data)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail={"error": "Invalid model name passed in"},
|
||||||
|
)
|
||||||
|
|
||||||
|
### ALERTING ###
|
||||||
|
data["litellm_status"] = "success" # used for alerting
|
||||||
|
end_time = time.time()
|
||||||
|
asyncio.create_task(
|
||||||
|
proxy_logging_obj.response_taking_too_long(
|
||||||
|
start_time=start_time, end_time=end_time, type="slow_response"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
await proxy_logging_obj.post_call_failure_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict, original_exception=e
|
||||||
|
)
|
||||||
|
traceback.print_exc()
|
||||||
|
if isinstance(e, HTTPException):
|
||||||
|
raise ProxyException(
|
||||||
|
message=getattr(e, "message", str(e)),
|
||||||
|
type=getattr(e, "type", "None"),
|
||||||
|
param=getattr(e, "param", "None"),
|
||||||
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
error_traceback = traceback.format_exc()
|
||||||
|
error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||||
|
raise ProxyException(
|
||||||
|
message=getattr(e, "message", error_msg),
|
||||||
|
type=getattr(e, "type", "None"),
|
||||||
|
param=getattr(e, "param", "None"),
|
||||||
|
code=getattr(e, "status_code", 500),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
#### KEY MANAGEMENT ####
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
|
@ -3684,7 +3839,6 @@ async def user_update(data: UpdateUserRequest):
|
||||||
code=status.HTTP_400_BAD_REQUEST,
|
code=status.HTTP_400_BAD_REQUEST,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
#### TEAM MANAGEMENT ####
|
#### TEAM MANAGEMENT ####
|
||||||
|
|
||||||
|
|
||||||
|
@ -3766,75 +3920,182 @@ async def team_info(
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
get info on team + related keys
|
get info on team + related keys
|
||||||
|
|
||||||
```
|
|
||||||
curl --location 'http://localhost:4000/team/info' \
|
|
||||||
--header 'Authorization: Bearer sk-1234' \
|
|
||||||
--header 'Content-Type: application/json' \
|
|
||||||
--data '{
|
|
||||||
"teams": ["<team-id>",..]
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
"""
|
"""
|
||||||
global prisma_client
|
pass
|
||||||
try:
|
|
||||||
if prisma_client is None:
|
@app.get("/sso/callback", tags=["experimental"])
|
||||||
raise HTTPException(
|
async def auth_callback(request: Request):
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
"""Verify login"""
|
||||||
detail={
|
global general_settings
|
||||||
"error": f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
|
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||||
},
|
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||||
)
|
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||||
if team_id is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
|
||||||
detail={"message": "Malformed request. No team id passed in."},
|
|
||||||
)
|
|
||||||
|
|
||||||
team_info = await prisma_client.get_data(
|
# get url from request
|
||||||
team_id=team_id, table_name="team", query_type="find_unique"
|
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
||||||
)
|
|
||||||
## GET ALL KEYS ##
|
|
||||||
keys = await prisma_client.get_data(
|
|
||||||
team_id=team_id,
|
|
||||||
table_name="key",
|
|
||||||
query_type="find_all",
|
|
||||||
expires=datetime.now(),
|
|
||||||
)
|
|
||||||
|
|
||||||
if team_info is None:
|
if redirect_url.endswith("/"):
|
||||||
## make sure we still return a total spend ##
|
redirect_url += "sso/callback"
|
||||||
spend = 0
|
else:
|
||||||
for k in keys:
|
redirect_url += "/sso/callback"
|
||||||
spend += getattr(k, "spend", 0)
|
|
||||||
team_info = {"spend": spend}
|
|
||||||
|
|
||||||
## REMOVE HASHED TOKEN INFO before returning ##
|
if google_client_id is not None:
|
||||||
for key in keys:
|
from fastapi_sso.sso.google import GoogleSSO
|
||||||
try:
|
|
||||||
key = key.model_dump() # noqa
|
|
||||||
except:
|
|
||||||
# if using pydantic v1
|
|
||||||
key = key.dict()
|
|
||||||
key.pop("token", None)
|
|
||||||
return {"team_id": team_id, "team_info": team_info, "keys": keys}
|
|
||||||
|
|
||||||
except Exception as e:
|
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
||||||
if isinstance(e, HTTPException):
|
if google_client_secret is None:
|
||||||
raise ProxyException(
|
raise ProxyException(
|
||||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
||||||
type="auth_error",
|
type="auth_error",
|
||||||
param=getattr(e, "param", "None"),
|
param="GOOGLE_CLIENT_SECRET",
|
||||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
)
|
)
|
||||||
elif isinstance(e, ProxyException):
|
google_sso = GoogleSSO(
|
||||||
raise e
|
client_id=google_client_id,
|
||||||
raise ProxyException(
|
redirect_uri=redirect_url,
|
||||||
message="Authentication Error, " + str(e),
|
client_secret=google_client_secret,
|
||||||
type="auth_error",
|
|
||||||
param=getattr(e, "param", "None"),
|
|
||||||
code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
)
|
)
|
||||||
|
result = await google_sso.verify_and_process(request)
|
||||||
|
|
||||||
|
elif microsoft_client_id is not None:
|
||||||
|
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
||||||
|
|
||||||
|
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
||||||
|
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
||||||
|
if microsoft_client_secret is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
||||||
|
type="auth_error",
|
||||||
|
param="MICROSOFT_CLIENT_SECRET",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if microsoft_tenant is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="MICROSOFT_TENANT not set. Set it in .env file",
|
||||||
|
type="auth_error",
|
||||||
|
param="MICROSOFT_TENANT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
|
||||||
|
microsoft_sso = MicrosoftSSO(
|
||||||
|
client_id=microsoft_client_id,
|
||||||
|
client_secret=microsoft_client_secret,
|
||||||
|
tenant=microsoft_tenant,
|
||||||
|
redirect_uri=redirect_url,
|
||||||
|
allow_insecure_http=True,
|
||||||
|
)
|
||||||
|
result = await microsoft_sso.verify_and_process(request)
|
||||||
|
elif generic_client_id is not None:
|
||||||
|
# make generic sso provider
|
||||||
|
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
|
||||||
|
|
||||||
|
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||||
|
generic_authorization_endpoint = os.getenv(
|
||||||
|
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
||||||
|
)
|
||||||
|
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
||||||
|
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
||||||
|
if generic_client_secret is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
||||||
|
type="auth_error",
|
||||||
|
param="GENERIC_CLIENT_SECRET",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_authorization_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
||||||
|
type="auth_error",
|
||||||
|
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_token_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
||||||
|
type="auth_error",
|
||||||
|
param="GENERIC_TOKEN_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_userinfo_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
||||||
|
type="auth_error",
|
||||||
|
param="GENERIC_USERINFO_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
||||||
|
)
|
||||||
|
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
discovery = DiscoveryDocument(
|
||||||
|
authorization_endpoint=generic_authorization_endpoint,
|
||||||
|
token_endpoint=generic_token_endpoint,
|
||||||
|
userinfo_endpoint=generic_userinfo_endpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
||||||
|
generic_sso = SSOProvider(
|
||||||
|
client_id=generic_client_id,
|
||||||
|
client_secret=generic_client_secret,
|
||||||
|
redirect_uri=redirect_url,
|
||||||
|
allow_insecure_http=True,
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(f"calling generic_sso.verify_and_process")
|
||||||
|
|
||||||
|
request_body = await request.body()
|
||||||
|
|
||||||
|
request_query_params = request.query_params
|
||||||
|
|
||||||
|
# get "code" from query params
|
||||||
|
code = request_query_params.get("code")
|
||||||
|
|
||||||
|
result = await generic_sso.verify_and_process(request)
|
||||||
|
verbose_proxy_logger.debug(f"generic result: {result}")
|
||||||
|
|
||||||
|
# User is Authe'd in - generate key for the UI to access Proxy
|
||||||
|
user_email = getattr(result, "email", None)
|
||||||
|
user_id = getattr(result, "id", None)
|
||||||
|
if user_id is None:
|
||||||
|
user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "")
|
||||||
|
|
||||||
|
response = await generate_key_helper_fn(
|
||||||
|
**{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard", "user_email": user_email} # type: ignore
|
||||||
|
)
|
||||||
|
key = response["token"] # type: ignore
|
||||||
|
user_id = response["user_id"] # type: ignore
|
||||||
|
|
||||||
|
litellm_dashboard_ui = "/ui/"
|
||||||
|
|
||||||
|
user_role = "app_owner"
|
||||||
|
if (
|
||||||
|
os.getenv("PROXY_ADMIN_ID", None) is not None
|
||||||
|
and os.environ["PROXY_ADMIN_ID"] == user_id
|
||||||
|
):
|
||||||
|
# checks if user is admin
|
||||||
|
user_role = "app_admin"
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
jwt_token = jwt.encode(
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"key": key,
|
||||||
|
"user_email": user_email,
|
||||||
|
"user_role": user_role,
|
||||||
|
},
|
||||||
|
"secret",
|
||||||
|
algorithm="HS256",
|
||||||
|
)
|
||||||
|
litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token
|
||||||
|
|
||||||
|
# if a user has logged in they should be allowed to create keys - this ensures that it's set to True
|
||||||
|
general_settings["allow_user_auth"] = True
|
||||||
|
return RedirectResponse(url=litellm_dashboard_ui)
|
||||||
|
|
||||||
|
|
||||||
#### MODEL MANAGEMENT ####
|
#### MODEL MANAGEMENT ####
|
||||||
|
@ -4260,6 +4521,73 @@ async def google_login(request: Request):
|
||||||
)
|
)
|
||||||
with microsoft_sso:
|
with microsoft_sso:
|
||||||
return await microsoft_sso.get_login_redirect()
|
return await microsoft_sso.get_login_redirect()
|
||||||
|
elif generic_client_id is not None:
|
||||||
|
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
|
||||||
|
|
||||||
|
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||||
|
generic_authorization_endpoint = os.getenv(
|
||||||
|
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
||||||
|
)
|
||||||
|
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
||||||
|
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
||||||
|
if generic_client_secret is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
||||||
|
type="auth_error",
|
||||||
|
param="GENERIC_CLIENT_SECRET",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_authorization_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
||||||
|
type="auth_error",
|
||||||
|
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_token_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
||||||
|
type="auth_error",
|
||||||
|
param="GENERIC_TOKEN_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_userinfo_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
||||||
|
type="auth_error",
|
||||||
|
param="GENERIC_USERINFO_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
||||||
|
)
|
||||||
|
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
discovery = DiscoveryDocument(
|
||||||
|
authorization_endpoint=generic_authorization_endpoint,
|
||||||
|
token_endpoint=generic_token_endpoint,
|
||||||
|
userinfo_endpoint=generic_userinfo_endpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
||||||
|
generic_sso = SSOProvider(
|
||||||
|
client_id=generic_client_id,
|
||||||
|
client_secret=generic_client_secret,
|
||||||
|
redirect_uri=redirect_url,
|
||||||
|
allow_insecure_http=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with generic_sso:
|
||||||
|
return await generic_sso.get_login_redirect()
|
||||||
|
|
||||||
|
elif ui_username is not None:
|
||||||
|
# No Google, Microsoft SSO
|
||||||
|
# Use UI Credentials set in .env
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
elif ui_username is not None:
|
elif ui_username is not None:
|
||||||
# No Google, Microsoft SSO
|
# No Google, Microsoft SSO
|
||||||
# Use UI Credentials set in .env
|
# Use UI Credentials set in .env
|
||||||
|
|
|
@ -93,7 +93,9 @@ class ProxyLogging:
|
||||||
self,
|
self,
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
data: dict,
|
data: dict,
|
||||||
call_type: Literal["completion", "embeddings", "image_generation"],
|
call_type: Literal[
|
||||||
|
"completion", "embeddings", "image_generation", "moderation"
|
||||||
|
],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
|
Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
|
||||||
|
|
|
@ -599,6 +599,98 @@ class Router:
|
||||||
self.fail_calls[model_name] += 1
|
self.fail_calls[model_name] += 1
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
async def amoderation(self, model: str, input: str, **kwargs):
|
||||||
|
try:
|
||||||
|
kwargs["model"] = model
|
||||||
|
kwargs["input"] = input
|
||||||
|
kwargs["original_function"] = self._amoderation
|
||||||
|
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||||
|
timeout = kwargs.get("request_timeout", self.timeout)
|
||||||
|
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||||
|
|
||||||
|
response = await self.async_function_with_fallbacks(**kwargs)
|
||||||
|
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def _amoderation(self, model: str, input: str, **kwargs):
|
||||||
|
model_name = None
|
||||||
|
try:
|
||||||
|
verbose_router_logger.debug(
|
||||||
|
f"Inside _moderation()- model: {model}; kwargs: {kwargs}"
|
||||||
|
)
|
||||||
|
deployment = self.get_available_deployment(
|
||||||
|
model=model,
|
||||||
|
input=input,
|
||||||
|
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||||
|
)
|
||||||
|
kwargs.setdefault("metadata", {}).update(
|
||||||
|
{
|
||||||
|
"deployment": deployment["litellm_params"]["model"],
|
||||||
|
"model_info": deployment.get("model_info", {}),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
kwargs["model_info"] = deployment.get("model_info", {})
|
||||||
|
data = deployment["litellm_params"].copy()
|
||||||
|
model_name = data["model"]
|
||||||
|
for k, v in self.default_litellm_params.items():
|
||||||
|
if (
|
||||||
|
k not in kwargs and v is not None
|
||||||
|
): # prioritize model-specific params > default router params
|
||||||
|
kwargs[k] = v
|
||||||
|
elif k == "metadata":
|
||||||
|
kwargs[k].update(v)
|
||||||
|
|
||||||
|
potential_model_client = self._get_client(
|
||||||
|
deployment=deployment, kwargs=kwargs, client_type="async"
|
||||||
|
)
|
||||||
|
# check if provided keys == client keys #
|
||||||
|
dynamic_api_key = kwargs.get("api_key", None)
|
||||||
|
if (
|
||||||
|
dynamic_api_key is not None
|
||||||
|
and potential_model_client is not None
|
||||||
|
and dynamic_api_key != potential_model_client.api_key
|
||||||
|
):
|
||||||
|
model_client = None
|
||||||
|
else:
|
||||||
|
model_client = potential_model_client
|
||||||
|
self.total_calls[model_name] += 1
|
||||||
|
|
||||||
|
timeout = (
|
||||||
|
data.get(
|
||||||
|
"timeout", None
|
||||||
|
) # timeout set on litellm_params for this deployment
|
||||||
|
or self.timeout # timeout set on router
|
||||||
|
or kwargs.get(
|
||||||
|
"timeout", None
|
||||||
|
) # this uses default_litellm_params when nothing is set
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await litellm.amoderation(
|
||||||
|
**{
|
||||||
|
**data,
|
||||||
|
"input": input,
|
||||||
|
"caching": self.cache_responses,
|
||||||
|
"client": model_client,
|
||||||
|
"timeout": timeout,
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.success_calls[model_name] += 1
|
||||||
|
verbose_router_logger.info(
|
||||||
|
f"litellm.amoderation(model={model_name})\033[32m 200 OK\033[0m"
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
verbose_router_logger.info(
|
||||||
|
f"litellm.amoderation(model={model_name})\033[31m Exception {str(e)}\033[0m"
|
||||||
|
)
|
||||||
|
if model_name is not None:
|
||||||
|
self.fail_calls[model_name] += 1
|
||||||
|
raise e
|
||||||
|
|
||||||
def text_completion(
|
def text_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
|
|
@ -86,7 +86,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
|
||||||
if isinstance(response_obj, ModelResponse):
|
if isinstance(response_obj, ModelResponse):
|
||||||
completion_tokens = response_obj.usage.completion_tokens
|
completion_tokens = response_obj.usage.completion_tokens
|
||||||
total_tokens = response_obj.usage.total_tokens
|
total_tokens = response_obj.usage.total_tokens
|
||||||
final_value = float(completion_tokens / response_ms.total_seconds())
|
final_value = float(response_ms.total_seconds() / completion_tokens)
|
||||||
|
|
||||||
# ------------
|
# ------------
|
||||||
# Update usage
|
# Update usage
|
||||||
|
@ -168,7 +168,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
|
||||||
if isinstance(response_obj, ModelResponse):
|
if isinstance(response_obj, ModelResponse):
|
||||||
completion_tokens = response_obj.usage.completion_tokens
|
completion_tokens = response_obj.usage.completion_tokens
|
||||||
total_tokens = response_obj.usage.total_tokens
|
total_tokens = response_obj.usage.total_tokens
|
||||||
final_value = float(completion_tokens / response_ms.total_seconds())
|
final_value = float(response_ms.total_seconds() / completion_tokens)
|
||||||
|
|
||||||
# ------------
|
# ------------
|
||||||
# Update usage
|
# Update usage
|
||||||
|
|
|
@ -2093,10 +2093,6 @@ def test_completion_cloudflare():
|
||||||
|
|
||||||
|
|
||||||
def test_moderation():
|
def test_moderation():
|
||||||
import openai
|
|
||||||
|
|
||||||
openai.api_type = "azure"
|
|
||||||
openai.api_version = "GM"
|
|
||||||
response = litellm.moderation(input="i'm ishaan cto of litellm")
|
response = litellm.moderation(input="i'm ishaan cto of litellm")
|
||||||
print(response)
|
print(response)
|
||||||
output = response.results[0]
|
output = response.results[0]
|
||||||
|
|
|
@ -991,3 +991,23 @@ def test_router_timeout():
|
||||||
print(e)
|
print(e)
|
||||||
print(vars(e))
|
print(vars(e))
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_router_amoderation():
|
||||||
|
model_list = [
|
||||||
|
{
|
||||||
|
"model_name": "openai-moderations",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "text-moderation-stable",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY", None),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
router = Router(model_list=model_list)
|
||||||
|
result = await router.amoderation(
|
||||||
|
model="openai-moderations", input="this is valid good text"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("moderation result", result)
|
||||||
|
|
|
@ -738,6 +738,8 @@ class CallTypes(Enum):
|
||||||
text_completion = "text_completion"
|
text_completion = "text_completion"
|
||||||
image_generation = "image_generation"
|
image_generation = "image_generation"
|
||||||
aimage_generation = "aimage_generation"
|
aimage_generation = "aimage_generation"
|
||||||
|
moderation = "moderation"
|
||||||
|
amoderation = "amoderation"
|
||||||
|
|
||||||
|
|
||||||
# Logging function -> log the exact model details + what's being sent | Non-BlockingP
|
# Logging function -> log the exact model details + what's being sent | Non-BlockingP
|
||||||
|
@ -2100,6 +2102,11 @@ def client(original_function):
|
||||||
or call_type == CallTypes.aimage_generation.value
|
or call_type == CallTypes.aimage_generation.value
|
||||||
):
|
):
|
||||||
messages = args[0] if len(args) > 0 else kwargs["prompt"]
|
messages = args[0] if len(args) > 0 else kwargs["prompt"]
|
||||||
|
elif (
|
||||||
|
call_type == CallTypes.moderation.value
|
||||||
|
or call_type == CallTypes.amoderation.value
|
||||||
|
):
|
||||||
|
messages = args[1] if len(args) > 1 else kwargs["input"]
|
||||||
elif (
|
elif (
|
||||||
call_type == CallTypes.atext_completion.value
|
call_type == CallTypes.atext_completion.value
|
||||||
or call_type == CallTypes.text_completion.value
|
or call_type == CallTypes.text_completion.value
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "litellm"
|
name = "litellm"
|
||||||
version = "1.23.15"
|
version = "1.23.16"
|
||||||
description = "Library to easily interface with LLM API providers"
|
description = "Library to easily interface with LLM API providers"
|
||||||
authors = ["BerriAI"]
|
authors = ["BerriAI"]
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
@ -69,7 +69,7 @@ requires = ["poetry-core", "wheel"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
[tool.commitizen]
|
[tool.commitizen]
|
||||||
version = "1.23.15"
|
version = "1.23.16"
|
||||||
version_files = [
|
version_files = [
|
||||||
"pyproject.toml:^version"
|
"pyproject.toml:^version"
|
||||||
]
|
]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue