LiteLLM Minor Fixes & Improvements (09/20/2024) (#5807)

* fix(vertex_llm_base.py): Handle api_base = ""

Fixes https://github.com/BerriAI/litellm/issues/5798

* fix(o1_transformation.py): handle stream_options not being supported

https://github.com/BerriAI/litellm/issues/5803
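
For example, a minimal sketch (assumes `litellm.drop_params` is enabled so unsupported params are stripped rather than raising, and that `OPENAI_API_KEY` is set):

```python
import litellm

# with drop_params on, stream_options is treated as unsupported for o1
# models and dropped instead of being forwarded upstream
litellm.drop_params = True

response = litellm.completion(
    model="o1-preview",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    stream_options={"include_usage": True},  # dropped for o1 models
)
```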

* docs(routing.md): fix docs

Closes https://github.com/BerriAI/litellm/issues/5808

* perf(internal_user_endpoints.py): reduce db calls for getting team_alias for a key

Reuse the team list already fetched in the `/user/info` endpoint.

Reduces the UI Keys tab load time to 800ms (previously 28s+).

* feat(proxy_server.py): support CONFIG_FILE_PATH as env var

Closes https://github.com/BerriAI/litellm/issues/5744
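
For example, a minimal sketch (the config path and port are placeholders; assumes the `litellm` proxy CLI is installed):

```python
import os
import subprocess

# point the proxy at a config file via env var instead of the --config flag
os.environ["CONFIG_FILE_PATH"] = "/app/config.yaml"  # placeholder path

# the proxy's startup_event() picks up CONFIG_FILE_PATH and loads the YAML from there
subprocess.run(["litellm", "--port", "4000"], check=True)
```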

* feat(get_llm_provider_logic.py): add `litellm_proxy/` as a known openai-compatible route

Simplifies calling the LiteLLM proxy from the LiteLLM SDK and reduces confusion when calling proxy-hosted models.
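
For example, a minimal sketch (the model name and proxy URL are placeholders):

```python
import os
from litellm import completion

# assumes a LiteLLM proxy is running at this base url with this key
os.environ["LITELLM_PROXY_API_KEY"] = "sk-1234"
os.environ["LITELLM_PROXY_API_BASE"] = "http://localhost:4000"

response = completion(
    model="litellm_proxy/your-model-name",  # routed straight to the proxy
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
```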

* docs(litellm_proxy.md): cleanup docs

* fix(internal_user_endpoints.py): fix pydantic obj

* test(test_key_generate_prisma.py): fix test
Krish Dholakia 2024-09-20 20:21:32 -07:00 committed by GitHub
parent 0c488cf4ca
commit d6ca7fed18
14 changed files with 204 additions and 84 deletions


@@ -9,13 +9,13 @@ import TabItem from '@theme/TabItem';
 :::
-**[LiteLLM Proxy](../simple_proxy) is OpenAI compatible**, you just need the `openai/` prefix before the model
+**[LiteLLM Proxy](../simple_proxy) is OpenAI compatible**, you just need the `litellm_proxy/` prefix before the model
 ## Required Variables
 ```python
-os.environ["OPENAI_API_KEY"] = "" # "sk-1234" your litellm proxy api key
-os.environ["OPENAI_API_BASE"] = "" # "http://localhost:4000" your litellm proxy api base
+os.environ["LITELLM_PROXY_API_KEY"] = "" # "sk-1234" your litellm proxy api key
+os.environ["LITELLM_PROXY_API_BASE"] = "" # "http://localhost:4000" your litellm proxy api base
 ```
@@ -25,18 +25,18 @@ import os
 import litellm
 from litellm import completion
-os.environ["OPENAI_API_KEY"] = ""
+os.environ["LITELLM_PROXY_API_KEY"] = ""
 # set custom api base to your proxy
 # either set .env or litellm.api_base
-# os.environ["OPENAI_API_BASE"] = ""
+# os.environ["LITELLM_PROXY_API_BASE"] = ""
 litellm.api_base = "your-openai-proxy-url"
 messages = [{ "content": "Hello, how are you?","role": "user"}]
-# openai call
-response = completion(model="openai/your-model-name", messages)
+# litellm proxy call
+response = completion(model="litellm_proxy/your-model-name", messages)
 ```
 ## Usage - passing `api_base`, `api_key` per request
@@ -48,13 +48,13 @@ import os
 import litellm
 from litellm import completion
-os.environ["OPENAI_API_KEY"] = ""
+os.environ["LITELLM_PROXY_API_KEY"] = ""
 messages = [{ "content": "Hello, how are you?","role": "user"}]
-# openai call
+# litellm proxy call
 response = completion(
-    model="openai/your-model-name",
+    model="litellm_proxy/your-model-name",
     messages,
     api_base = "your-litellm-proxy-url",
     api_key = "your-litellm-proxy-api-key"
@@ -67,13 +67,13 @@ import os
 import litellm
 from litellm import completion
-os.environ["OPENAI_API_KEY"] = ""
+os.environ["LITELLM_PROXY_API_KEY"] = ""
 messages = [{ "content": "Hello, how are you?","role": "user"}]
 # openai call
 response = completion(
-    model="openai/your-model-name",
+    model="litellm_proxy/your-model-name",
     messages,
     api_base = "your-litellm-proxy-url",
     stream=True


@@ -183,9 +183,9 @@ model_list = [{ # list of model deployments
         "api_key": os.getenv("AZURE_API_KEY"),
         "api_version": os.getenv("AZURE_API_VERSION"),
         "api_base": os.getenv("AZURE_API_BASE")
+        "tpm": 100000,
+        "rpm": 10000,
     },
-    "tpm": 100000,
-    "rpm": 10000,
 }, {
     "model_name": "gpt-3.5-turbo",
     "litellm_params": { # params for litellm completion/embedding call
@@ -193,24 +193,24 @@ model_list = [{ # list of model deployments
         "api_key": os.getenv("AZURE_API_KEY"),
         "api_version": os.getenv("AZURE_API_VERSION"),
         "api_base": os.getenv("AZURE_API_BASE")
+        "tpm": 100000,
+        "rpm": 1000,
     },
-    "tpm": 100000,
-    "rpm": 1000,
 }, {
     "model_name": "gpt-3.5-turbo",
     "litellm_params": { # params for litellm completion/embedding call
         "model": "gpt-3.5-turbo",
         "api_key": os.getenv("OPENAI_API_KEY"),
+        "tpm": 100000,
+        "rpm": 1000,
     },
-    "tpm": 100000,
-    "rpm": 1000,
 }]
 router = Router(model_list=model_list,
                 redis_host=os.environ["REDIS_HOST"],
                 redis_password=os.environ["REDIS_PASSWORD"],
                 redis_port=os.environ["REDIS_PORT"],
                 routing_strategy="usage-based-routing-v2" # 👈 KEY CHANGE
-                enable_pre_call_check=True, # enables router rate limits for concurrent calls
+                enable_pre_call_checks=True, # enables router rate limits for concurrent calls
                 )
 response = await router.acompletion(model="gpt-3.5-turbo",


@@ -495,6 +495,7 @@ openai_compatible_providers: List = [
     "friendliai",
     "azure_ai",
     "github",
+    "litellm_proxy",
 ]
 openai_text_completion_compatible_providers: List = (
     [  # providers that support `/v1/completions`
@@ -748,6 +749,7 @@ class LlmProviders(str, Enum):
     EMPOWER = "empower"
     GITHUB = "github"
     CUSTOM = "custom"
+    LITELLM_PROXY = "litellm_proxy"
 provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)


@@ -243,6 +243,9 @@ def get_llm_provider(
     elif custom_llm_provider == "github":
         api_base = api_base or get_secret("GITHUB_API_BASE") or "https://models.inference.ai.azure.com"  # type: ignore
         dynamic_api_key = api_key or get_secret("GITHUB_API_KEY")
+    elif custom_llm_provider == "litellm_proxy":
+        api_base = api_base or get_secret("LITELLM_PROXY_API_BASE") or "https://models.inference.ai.azure.com"  # type: ignore
+        dynamic_api_key = api_key or get_secret("LITELLM_PROXY_API_KEY")
     elif custom_llm_provider == "mistral":
         # mistral is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.mistral.ai


@@ -65,6 +65,7 @@ class OpenAIO1Config(OpenAIGPTConfig):
             "top_logprobs",
             "response_format",
             "stop",
+            "stream_options",
         ]
         return [


@@ -16,6 +16,7 @@ from litellm.types.llms.vertex_ai import (
 from litellm.utils import ModelResponse
 from ..common_utils import VertexAIError
+from ..vertex_llm_base import VertexBase
 from .transformation import (
     separate_cached_messages,
     transform_openai_messages_to_gemini_context_caching,
@@ -24,7 +25,7 @@ from .transformation import (
 local_cache_obj = Cache(type="local")  # only used for calling 'get_cache_key' function
-class ContextCachingEndpoints:
+class ContextCachingEndpoints(VertexBase):
     """
     Covers context caching endpoints for Vertex AI + Google AI Studio
@@ -34,7 +35,7 @@ class ContextCachingEndpoints:
     def __init__(self) -> None:
         pass
-    def _get_token_and_url(
+    def _get_token_and_url_context_caching(
         self,
         gemini_api_key: Optional[str],
         custom_llm_provider: Literal["gemini"],
@@ -57,18 +58,16 @@ class ContextCachingEndpoints:
         else:
             raise NotImplementedError
-        if (
-            api_base is not None
-        ):  # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
-            if custom_llm_provider == "gemini":
-                url = "{}/{}".format(api_base, endpoint)
-                auth_header = (
-                    gemini_api_key  # cloudflare expects api key as bearer token
-                )
-            else:
-                url = "{}:{}".format(api_base, endpoint)
-        return auth_header, url
+        return self._check_custom_proxy(
+            api_base=api_base,
+            custom_llm_provider=custom_llm_provider,
+            gemini_api_key=gemini_api_key,
+            endpoint=endpoint,
+            stream=None,
+            auth_header=auth_header,
+            url=url,
+        )
     def check_cache(
         self,
@@ -90,7 +89,7 @@ class ContextCachingEndpoints:
         - None
         """
-        _, url = self._get_token_and_url(
+        _, url = self._get_token_and_url_context_caching(
             gemini_api_key=api_key,
             custom_llm_provider="gemini",
             api_base=api_base,
@@ -125,11 +124,12 @@ class ContextCachingEndpoints:
         all_cached_items = CachedContentListAllResponseBody(**raw_response)
+        if "cachedContents" not in all_cached_items:
+            return None
         for cached_item in all_cached_items["cachedContents"]:
-            if (
-                cached_item.get("displayName") is not None
-                and cached_item["displayName"] == cache_key
-            ):
+            display_name = cached_item.get("displayName")
+            if display_name is not None and display_name == cache_key:
                 return cached_item.get("name")
         return None
@@ -154,7 +154,7 @@ class ContextCachingEndpoints:
         - None
         """
-        _, url = self._get_token_and_url(
+        _, url = self._get_token_and_url_context_caching(
             gemini_api_key=api_key,
             custom_llm_provider="gemini",
             api_base=api_base,
@@ -189,11 +189,12 @@ class ContextCachingEndpoints:
         all_cached_items = CachedContentListAllResponseBody(**raw_response)
+        if "cachedContents" not in all_cached_items:
+            return None
         for cached_item in all_cached_items["cachedContents"]:
-            if (
-                cached_item.get("displayName") is not None
-                and cached_item["displayName"] == cache_key
-            ):
+            display_name = cached_item.get("displayName")
+            if display_name is not None and display_name == cache_key:
                 return cached_item.get("name")
         return None
@@ -224,7 +225,7 @@ class ContextCachingEndpoints:
             return messages, cached_content
         ## AUTHORIZATION ##
-        token, url = self._get_token_and_url(
+        token, url = self._get_token_and_url_context_caching(
             gemini_api_key=api_key,
             custom_llm_provider="gemini",
             api_base=api_base,
@@ -329,7 +330,7 @@ class ContextCachingEndpoints:
             return messages, cached_content
         ## AUTHORIZATION ##
-        token, url = self._get_token_and_url(
+        token, url = self._get_token_and_url_context_caching(
             gemini_api_key=api_key,
             custom_llm_provider="gemini",
             api_base=api_base,


@@ -86,7 +86,7 @@ class VertexBase(BaseLLM):
             )
             if project_id is None:
-                project_id = creds.project_id
+                project_id = getattr(creds, "project_id", None)
         else:
             creds, creds_project_id = google_auth.default(
                 quota_project_id=project_id,
@@ -95,7 +95,7 @@ class VertexBase(BaseLLM):
             if project_id is None:
                 project_id = creds_project_id
-        creds.refresh(Request())
+        creds.refresh(Request())  # type: ignore
         if not project_id:
             raise ValueError("Could not resolve project_id")
@@ -169,6 +169,39 @@ class VertexBase(BaseLLM):
             return True
         return False
+    def _check_custom_proxy(
+        self,
+        api_base: Optional[str],
+        custom_llm_provider: str,
+        gemini_api_key: Optional[str],
+        endpoint: str,
+        stream: Optional[bool],
+        auth_header: Optional[str],
+        url: str,
+    ) -> Tuple[Optional[str], str]:
+        """
+        for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
+        ## Returns
+        - (auth_header, url) - Tuple[Optional[str], str]
+        """
+        if api_base:
+            if custom_llm_provider == "gemini":
+                url = "{}:{}".format(api_base, endpoint)
+                if gemini_api_key is None:
+                    raise ValueError(
+                        "Missing gemini_api_key, please set `GEMINI_API_KEY`"
+                    )
+                auth_header = (
+                    gemini_api_key  # cloudflare expects api key as bearer token
+                )
+            else:
+                url = "{}:{}".format(api_base, endpoint)
+        if stream is True:
+            url = url + "?alt=sse"
+        return auth_header, url
     def _get_token_and_url(
         self,
         model: str,
@@ -215,25 +248,15 @@ class VertexBase(BaseLLM):
             vertex_api_version=version,
         )
-        if (
-            api_base is not None
-        ):  # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
-            if custom_llm_provider == "gemini":
-                url = "{}:{}".format(api_base, endpoint)
-                if gemini_api_key is None:
-                    raise ValueError(
-                        "Missing gemini_api_key, please set `GEMINI_API_KEY`"
-                    )
-                auth_header = (
-                    gemini_api_key  # cloudflare expects api key as bearer token
-                )
-            else:
-                url = "{}:{}".format(api_base, endpoint)
-        if stream is True:
-            url = url + "?alt=sse"
-        return auth_header, url
+        return self._check_custom_proxy(
+            api_base=api_base,
+            auth_header=auth_header,
+            custom_llm_provider=custom_llm_provider,
+            gemini_api_key=gemini_api_key,
+            endpoint=endpoint,
+            stream=stream,
+            url=url,
+        )
     async def _ensure_access_token_async(
         self,


@@ -1422,6 +1422,13 @@ class UserAPIKeyAuth(
         arbitrary_types_allowed = True
+class UserInfoResponse(LiteLLMBase):
+    user_id: Optional[str]
+    user_info: Optional[Union[dict, BaseModel]]
+    keys: List
+    teams: List
 class LiteLLM_Config(LiteLLMBase):
     param_name: str
     param_value: Dict


@@ -15,6 +15,7 @@ import copy
 import json
 import re
 import secrets
+import time
 import traceback
 import uuid
 from datetime import datetime, timedelta, timezone
@@ -274,10 +275,23 @@ async def ui_get_available_role(
     return _data_to_return
+def get_team_from_list(
+    team_list: Optional[List[LiteLLM_TeamTable]], team_id: str
+) -> Optional[LiteLLM_TeamTable]:
+    if team_list is None:
+        return None
+    for team in team_list:
+        if team.team_id == team_id:
+            return team
+    return None
 @router.get(
     "/user/info",
     tags=["Internal User management"],
     dependencies=[Depends(user_api_key_auth)],
+    response_model=UserInfoResponse,
 )
 @management_endpoint_wrapper
 async def user_info(
@@ -337,7 +351,7 @@ async def user_info(
         ## GET ALL TEAMS ##
         team_list = []
         team_id_list = []
-        # _DEPRECATED_ check if user in 'member' field
+        # get all teams user belongs to
         teams_1 = await prisma_client.get_data(
             user_id=user_id, table_name="team", query_type="find_all"
         )
@@ -414,13 +428,13 @@ async def user_info(
             if (
                 key.token == litellm_master_key_hash
                 and general_settings.get("disable_master_key_return", False)
-                == True  ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui
+                is True  ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui
             ):
                 continue
             try:
                 key = key.model_dump()  # noqa
-            except:
+            except Exception:
                 # if using pydantic v1
                 key = key.dict()
             if (
@@ -428,29 +442,29 @@ async def user_info(
                 and key["team_id"] is not None
                 and key["team_id"] != "litellm-dashboard"
             ):
-                team_info = await prisma_client.get_data(
-                    team_id=key["team_id"], table_name="team"
+                team_info = get_team_from_list(
+                    team_list=teams_1, team_id=key["team_id"]
                 )
-                team_alias = getattr(team_info, "team_alias", None)
-                key["team_alias"] = team_alias
+                if team_info is not None:
+                    team_alias = getattr(team_info, "team_alias", None)
+                    key["team_alias"] = team_alias
+                else:
+                    key["team_alias"] = None
             else:
                 key["team_alias"] = "None"
             returned_keys.append(key)
-        response_data = {
-            "user_id": user_id,
-            "user_info": user_info,
-            "keys": returned_keys,
-            "teams": team_list,
-        }
+        response_data = UserInfoResponse(
+            user_id=user_id, user_info=user_info, keys=returned_keys, teams=team_list
+        )
         return response_data
     except Exception as e:
-        verbose_proxy_logger.error(
+        verbose_proxy_logger.exception(
             "litellm.proxy.proxy_server.user_info(): Exception occured - {}".format(
                 str(e)
             )
         )
-        verbose_proxy_logger.debug(traceback.format_exc())
         if isinstance(e, HTTPException):
             raise ProxyException(
                 message=getattr(e, "detail", f"Authentication Error({str(e)})"),


@@ -248,7 +248,7 @@ from litellm.secret_managers.aws_secret_manager import (
     load_aws_secret_manager,
 )
 from litellm.secret_managers.google_kms import load_google_kms
-from litellm.secret_managers.main import get_secret, str_to_bool
+from litellm.secret_managers.main import get_secret, get_secret_str, str_to_bool
 from litellm.types.llms.anthropic import (
     AnthropicMessagesRequest,
     AnthropicResponse,
@@ -2728,9 +2728,21 @@ async def startup_event():
     ### LOAD CONFIG ###
     worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG")  # type: ignore
+    env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")
     verbose_proxy_logger.debug("worker_config: %s", worker_config)
     # check if it's a valid file path
-    if worker_config is not None:
+    if env_config_yaml is not None:
+        if os.path.isfile(env_config_yaml) and proxy_config.is_yaml(
+            config_file_path=env_config_yaml
+        ):
+            (
+                llm_router,
+                llm_model_list,
+                general_settings,
+            ) = await proxy_config.load_config(
+                router=llm_router, config_file_path=env_config_yaml
+            )
+    elif worker_config is not None:
         if (
             isinstance(worker_config, str)
             and os.path.isfile(worker_config)


@@ -31,6 +31,7 @@ from litellm import (
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     _gemini_convert_messages_with_history,
 )
+from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase
 from litellm.tests.test_streaming import streaming_format_tests
 litellm.num_retries = 3
@@ -2869,3 +2870,24 @@ def test_gemini_finetuned_endpoint(base_model, metadata):
     assert mock_client.call_args.kwargs["url"].endswith(
         "endpoints/4965075652664360960:generateContent"
     )
+@pytest.mark.parametrize("api_base", ["", None, "my-custom-proxy-base"])
+def test_custom_api_base(api_base):
+    stream = None
+    test_endpoint = "my-fake-endpoint"
+    vertex_base = VertexBase()
+    auth_header, url = vertex_base._check_custom_proxy(
+        api_base=api_base,
+        custom_llm_provider="gemini",
+        gemini_api_key="12324",
+        endpoint="",
+        stream=stream,
+        auth_header=None,
+        url="my-fake-endpoint",
+    )
+    if api_base:
+        assert url == api_base + ":"
+    else:
+        assert url == test_endpoint


@@ -1933,6 +1933,41 @@ async def test_openai_compatible_custom_api_base(provider):
     assert "hello" in mock_call.call_args.kwargs["extra_body"]
+@pytest.mark.asyncio
+async def test_litellm_gateway_from_sdk():
+    litellm.set_verbose = True
+    messages = [
+        {
+            "role": "user",
+            "content": "Hello world",
+        }
+    ]
+    from openai import OpenAI
+    openai_client = OpenAI(api_key="fake-key")
+    with patch.object(
+        openai_client.chat.completions, "create", new=MagicMock()
+    ) as mock_call:
+        try:
+            completion(
+                model="litellm_proxy/my-vllm-model",
+                messages=messages,
+                response_format={"type": "json_object"},
+                client=openai_client,
+                api_base="my-custom-api-base",
+                hello="world",
+            )
+        except Exception as e:
+            print(e)
+        mock_call.assert_called_once()
+        print("Call KWARGS - {}".format(mock_call.call_args.kwargs))
+        assert "hello" in mock_call.call_args.kwargs["extra_body"]
 # ################### Hugging Face Conversational models ########################
 # def hf_test_completion_conv():
 #     try:


@@ -256,7 +256,7 @@ def test_generate_and_call_with_valid_key(prisma_client, api_route):
         # check /user/info to verify user_role was set correctly
         new_user_info = await user_info(user_id=user_id)
-        new_user_info = new_user_info["user_info"]
+        new_user_info = new_user_info.user_info
         print("new_user_info=", new_user_info)
         assert new_user_info.user_role == LitellmUserRoles.INTERNAL_USER
         assert new_user_info.user_id == user_id


@@ -2154,7 +2154,7 @@ def test_openai_chat_completion_complete_response_call():
 # test_openai_chat_completion_complete_response_call()
 @pytest.mark.parametrize(
     "model",
-    ["gpt-3.5-turbo", "azure/chatgpt-v-2", "claude-3-haiku-20240307"],  #
+    ["gpt-3.5-turbo", "azure/chatgpt-v-2", "claude-3-haiku-20240307", "o1-preview"],  #
 )
 @pytest.mark.parametrize(
     "sync",