forked from phoenix/litellm-mirror
LiteLLM Minor Fixes & Improvements (09/20/2024) (#5807)
* fix(vertex_llm_base.py): Handle api_base = "" Fixes https://github.com/BerriAI/litellm/issues/5798 * fix(o1_transformation.py): handle stream_options not being supported https://github.com/BerriAI/litellm/issues/5803 * docs(routing.md): fix docs Closes https://github.com/BerriAI/litellm/issues/5808 * perf(internal_user_endpoints.py): reduce db calls for getting team_alias for a key Use the list gotten earlier in `/user/info` endpoint Reduces ui keys tab load time to 800ms (prev. 28s+) * feat(proxy_server.py): support CONFIG_FILE_PATH as env var Closes https://github.com/BerriAI/litellm/issues/5744 * feat(get_llm_provider_logic.py): add `litellm_proxy/` as a known openai-compatible route simplifies calling litellm proxy Reduces confusion when calling models on litellm proxy from litellm sdk * docs(litellm_proxy.md): cleanup docs * fix(internal_user_endpoints.py): fix pydantic obj * test(test_key_generate_prisma.py): fix test
This commit is contained in:
parent
c9ceab0f1e
commit
7ed6938a3f
14 changed files with 204 additions and 84 deletions
|
@ -9,13 +9,13 @@ import TabItem from '@theme/TabItem';
|
|||
|
||||
:::
|
||||
|
||||
**[LiteLLM Proxy](../simple_proxy) is OpenAI compatible**, you just need the `openai/` prefix before the model
|
||||
**[LiteLLM Proxy](../simple_proxy) is OpenAI compatible**, you just need the `litellm_proxy/` prefix before the model
|
||||
|
||||
## Required Variables
|
||||
|
||||
```python
|
||||
os.environ["OPENAI_API_KEY"] = "" # "sk-1234" your litellm proxy api key
|
||||
os.environ["OPENAI_API_BASE"] = "" # "http://localhost:4000" your litellm proxy api base
|
||||
os.environ["LITELLM_PROXY_API_KEY"] = "" # "sk-1234" your litellm proxy api key
|
||||
os.environ["LITELLM_PROXY_API_BASE"] = "" # "http://localhost:4000" your litellm proxy api base
|
||||
```
|
||||
|
||||
|
||||
|
@ -25,18 +25,18 @@ import os
|
|||
import litellm
|
||||
from litellm import completion
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = ""
|
||||
os.environ["LITELLM_PROXY_API_KEY"] = ""
|
||||
|
||||
# set custom api base to your proxy
|
||||
# either set .env or litellm.api_base
|
||||
# os.environ["OPENAI_API_BASE"] = ""
|
||||
# os.environ["LITELLM_PROXY_API_BASE"] = ""
|
||||
litellm.api_base = "your-openai-proxy-url"
|
||||
|
||||
|
||||
messages = [{ "content": "Hello, how are you?","role": "user"}]
|
||||
|
||||
# openai call
|
||||
response = completion(model="openai/your-model-name", messages)
|
||||
# litellm proxy call
|
||||
response = completion(model="litellm_proxy/your-model-name", messages)
|
||||
```
|
||||
|
||||
## Usage - passing `api_base`, `api_key` per request
|
||||
|
@ -48,13 +48,13 @@ import os
|
|||
import litellm
|
||||
from litellm import completion
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = ""
|
||||
os.environ["LITELLM_PROXY_API_KEY"] = ""
|
||||
|
||||
messages = [{ "content": "Hello, how are you?","role": "user"}]
|
||||
|
||||
# openai call
|
||||
# litellm proxy call
|
||||
response = completion(
|
||||
model="openai/your-model-name",
|
||||
model="litellm_proxy/your-model-name",
|
||||
messages,
|
||||
api_base = "your-litellm-proxy-url",
|
||||
api_key = "your-litellm-proxy-api-key"
|
||||
|
@ -67,13 +67,13 @@ import os
|
|||
import litellm
|
||||
from litellm import completion
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = ""
|
||||
os.environ["LITELLM_PROXY_API_KEY"] = ""
|
||||
|
||||
messages = [{ "content": "Hello, how are you?","role": "user"}]
|
||||
|
||||
# openai call
|
||||
response = completion(
|
||||
model="openai/your-model-name",
|
||||
model="litellm_proxy/your-model-name",
|
||||
messages,
|
||||
api_base = "your-litellm-proxy-url",
|
||||
stream=True
|
||||
|
|
|
@ -183,9 +183,9 @@ model_list = [{ # list of model deployments
|
|||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
"tpm": 100000,
|
||||
"rpm": 10000,
|
||||
},
|
||||
"tpm": 100000,
|
||||
"rpm": 10000,
|
||||
}, {
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
|
@ -193,24 +193,24 @@ model_list = [{ # list of model deployments
|
|||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
"tpm": 100000,
|
||||
"rpm": 1000,
|
||||
},
|
||||
"tpm": 100000,
|
||||
"rpm": 1000,
|
||||
}, {
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
"tpm": 100000,
|
||||
"rpm": 1000,
|
||||
},
|
||||
"tpm": 100000,
|
||||
"rpm": 1000,
|
||||
}]
|
||||
router = Router(model_list=model_list,
|
||||
redis_host=os.environ["REDIS_HOST"],
|
||||
redis_password=os.environ["REDIS_PASSWORD"],
|
||||
redis_port=os.environ["REDIS_PORT"],
|
||||
routing_strategy="usage-based-routing-v2" # 👈 KEY CHANGE
|
||||
enable_pre_call_check=True, # enables router rate limits for concurrent calls
|
||||
enable_pre_call_checks=True, # enables router rate limits for concurrent calls
|
||||
)
|
||||
|
||||
response = await router.acompletion(model="gpt-3.5-turbo",
|
||||
|
|
|
@ -495,6 +495,7 @@ openai_compatible_providers: List = [
|
|||
"friendliai",
|
||||
"azure_ai",
|
||||
"github",
|
||||
"litellm_proxy",
|
||||
]
|
||||
openai_text_completion_compatible_providers: List = (
|
||||
[ # providers that support `/v1/completions`
|
||||
|
@ -748,6 +749,7 @@ class LlmProviders(str, Enum):
|
|||
EMPOWER = "empower"
|
||||
GITHUB = "github"
|
||||
CUSTOM = "custom"
|
||||
LITELLM_PROXY = "litellm_proxy"
|
||||
|
||||
|
||||
provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)
|
||||
|
|
|
@ -243,6 +243,9 @@ def get_llm_provider(
|
|||
elif custom_llm_provider == "github":
|
||||
api_base = api_base or get_secret("GITHUB_API_BASE") or "https://models.inference.ai.azure.com" # type: ignore
|
||||
dynamic_api_key = api_key or get_secret("GITHUB_API_KEY")
|
||||
elif custom_llm_provider == "litellm_proxy":
|
||||
api_base = api_base or get_secret("LITELLM_PROXY_API_BASE") or "https://models.inference.ai.azure.com" # type: ignore
|
||||
dynamic_api_key = api_key or get_secret("LITELLM_PROXY_API_KEY")
|
||||
|
||||
elif custom_llm_provider == "mistral":
|
||||
# mistral is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.mistral.ai
|
||||
|
|
|
@ -65,6 +65,7 @@ class OpenAIO1Config(OpenAIGPTConfig):
|
|||
"top_logprobs",
|
||||
"response_format",
|
||||
"stop",
|
||||
"stream_options",
|
||||
]
|
||||
|
||||
return [
|
||||
|
|
|
@ -16,6 +16,7 @@ from litellm.types.llms.vertex_ai import (
|
|||
from litellm.utils import ModelResponse
|
||||
|
||||
from ..common_utils import VertexAIError
|
||||
from ..vertex_llm_base import VertexBase
|
||||
from .transformation import (
|
||||
separate_cached_messages,
|
||||
transform_openai_messages_to_gemini_context_caching,
|
||||
|
@ -24,7 +25,7 @@ from .transformation import (
|
|||
local_cache_obj = Cache(type="local") # only used for calling 'get_cache_key' function
|
||||
|
||||
|
||||
class ContextCachingEndpoints:
|
||||
class ContextCachingEndpoints(VertexBase):
|
||||
"""
|
||||
Covers context caching endpoints for Vertex AI + Google AI Studio
|
||||
|
||||
|
@ -34,7 +35,7 @@ class ContextCachingEndpoints:
|
|||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_token_and_url(
|
||||
def _get_token_and_url_context_caching(
|
||||
self,
|
||||
gemini_api_key: Optional[str],
|
||||
custom_llm_provider: Literal["gemini"],
|
||||
|
@ -57,18 +58,16 @@ class ContextCachingEndpoints:
|
|||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if (
|
||||
api_base is not None
|
||||
): # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
|
||||
if custom_llm_provider == "gemini":
|
||||
url = "{}/{}".format(api_base, endpoint)
|
||||
auth_header = (
|
||||
gemini_api_key # cloudflare expects api key as bearer token
|
||||
)
|
||||
else:
|
||||
url = "{}:{}".format(api_base, endpoint)
|
||||
|
||||
return auth_header, url
|
||||
return self._check_custom_proxy(
|
||||
api_base=api_base,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
gemini_api_key=gemini_api_key,
|
||||
endpoint=endpoint,
|
||||
stream=None,
|
||||
auth_header=auth_header,
|
||||
url=url,
|
||||
)
|
||||
|
||||
def check_cache(
|
||||
self,
|
||||
|
@ -90,7 +89,7 @@ class ContextCachingEndpoints:
|
|||
- None
|
||||
"""
|
||||
|
||||
_, url = self._get_token_and_url(
|
||||
_, url = self._get_token_and_url_context_caching(
|
||||
gemini_api_key=api_key,
|
||||
custom_llm_provider="gemini",
|
||||
api_base=api_base,
|
||||
|
@ -125,11 +124,12 @@ class ContextCachingEndpoints:
|
|||
|
||||
all_cached_items = CachedContentListAllResponseBody(**raw_response)
|
||||
|
||||
if "cachedContents" not in all_cached_items:
|
||||
return None
|
||||
|
||||
for cached_item in all_cached_items["cachedContents"]:
|
||||
if (
|
||||
cached_item.get("displayName") is not None
|
||||
and cached_item["displayName"] == cache_key
|
||||
):
|
||||
display_name = cached_item.get("displayName")
|
||||
if display_name is not None and display_name == cache_key:
|
||||
return cached_item.get("name")
|
||||
|
||||
return None
|
||||
|
@ -154,7 +154,7 @@ class ContextCachingEndpoints:
|
|||
- None
|
||||
"""
|
||||
|
||||
_, url = self._get_token_and_url(
|
||||
_, url = self._get_token_and_url_context_caching(
|
||||
gemini_api_key=api_key,
|
||||
custom_llm_provider="gemini",
|
||||
api_base=api_base,
|
||||
|
@ -189,11 +189,12 @@ class ContextCachingEndpoints:
|
|||
|
||||
all_cached_items = CachedContentListAllResponseBody(**raw_response)
|
||||
|
||||
if "cachedContents" not in all_cached_items:
|
||||
return None
|
||||
|
||||
for cached_item in all_cached_items["cachedContents"]:
|
||||
if (
|
||||
cached_item.get("displayName") is not None
|
||||
and cached_item["displayName"] == cache_key
|
||||
):
|
||||
display_name = cached_item.get("displayName")
|
||||
if display_name is not None and display_name == cache_key:
|
||||
return cached_item.get("name")
|
||||
|
||||
return None
|
||||
|
@ -224,7 +225,7 @@ class ContextCachingEndpoints:
|
|||
return messages, cached_content
|
||||
|
||||
## AUTHORIZATION ##
|
||||
token, url = self._get_token_and_url(
|
||||
token, url = self._get_token_and_url_context_caching(
|
||||
gemini_api_key=api_key,
|
||||
custom_llm_provider="gemini",
|
||||
api_base=api_base,
|
||||
|
@ -329,7 +330,7 @@ class ContextCachingEndpoints:
|
|||
return messages, cached_content
|
||||
|
||||
## AUTHORIZATION ##
|
||||
token, url = self._get_token_and_url(
|
||||
token, url = self._get_token_and_url_context_caching(
|
||||
gemini_api_key=api_key,
|
||||
custom_llm_provider="gemini",
|
||||
api_base=api_base,
|
||||
|
|
|
@ -86,7 +86,7 @@ class VertexBase(BaseLLM):
|
|||
)
|
||||
|
||||
if project_id is None:
|
||||
project_id = creds.project_id
|
||||
project_id = getattr(creds, "project_id", None)
|
||||
else:
|
||||
creds, creds_project_id = google_auth.default(
|
||||
quota_project_id=project_id,
|
||||
|
@ -95,7 +95,7 @@ class VertexBase(BaseLLM):
|
|||
if project_id is None:
|
||||
project_id = creds_project_id
|
||||
|
||||
creds.refresh(Request())
|
||||
creds.refresh(Request()) # type: ignore
|
||||
|
||||
if not project_id:
|
||||
raise ValueError("Could not resolve project_id")
|
||||
|
@ -169,6 +169,39 @@ class VertexBase(BaseLLM):
|
|||
return True
|
||||
return False
|
||||
|
||||
def _check_custom_proxy(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
custom_llm_provider: str,
|
||||
gemini_api_key: Optional[str],
|
||||
endpoint: str,
|
||||
stream: Optional[bool],
|
||||
auth_header: Optional[str],
|
||||
url: str,
|
||||
) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
|
||||
|
||||
## Returns
|
||||
- (auth_header, url) - Tuple[Optional[str], str]
|
||||
"""
|
||||
if api_base:
|
||||
if custom_llm_provider == "gemini":
|
||||
url = "{}:{}".format(api_base, endpoint)
|
||||
if gemini_api_key is None:
|
||||
raise ValueError(
|
||||
"Missing gemini_api_key, please set `GEMINI_API_KEY`"
|
||||
)
|
||||
auth_header = (
|
||||
gemini_api_key # cloudflare expects api key as bearer token
|
||||
)
|
||||
else:
|
||||
url = "{}:{}".format(api_base, endpoint)
|
||||
|
||||
if stream is True:
|
||||
url = url + "?alt=sse"
|
||||
return auth_header, url
|
||||
|
||||
def _get_token_and_url(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -215,25 +248,15 @@ class VertexBase(BaseLLM):
|
|||
vertex_api_version=version,
|
||||
)
|
||||
|
||||
if (
|
||||
api_base is not None
|
||||
): # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
|
||||
if custom_llm_provider == "gemini":
|
||||
url = "{}:{}".format(api_base, endpoint)
|
||||
if gemini_api_key is None:
|
||||
raise ValueError(
|
||||
"Missing gemini_api_key, please set `GEMINI_API_KEY`"
|
||||
)
|
||||
auth_header = (
|
||||
gemini_api_key # cloudflare expects api key as bearer token
|
||||
)
|
||||
else:
|
||||
url = "{}:{}".format(api_base, endpoint)
|
||||
|
||||
if stream is True:
|
||||
url = url + "?alt=sse"
|
||||
|
||||
return auth_header, url
|
||||
return self._check_custom_proxy(
|
||||
api_base=api_base,
|
||||
auth_header=auth_header,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
gemini_api_key=gemini_api_key,
|
||||
endpoint=endpoint,
|
||||
stream=stream,
|
||||
url=url,
|
||||
)
|
||||
|
||||
async def _ensure_access_token_async(
|
||||
self,
|
||||
|
|
|
@ -1422,6 +1422,13 @@ class UserAPIKeyAuth(
|
|||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class UserInfoResponse(LiteLLMBase):
|
||||
user_id: Optional[str]
|
||||
user_info: Optional[Union[dict, BaseModel]]
|
||||
keys: List
|
||||
teams: List
|
||||
|
||||
|
||||
class LiteLLM_Config(LiteLLMBase):
|
||||
param_name: str
|
||||
param_value: Dict
|
||||
|
|
|
@ -15,6 +15,7 @@ import copy
|
|||
import json
|
||||
import re
|
||||
import secrets
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
@ -274,10 +275,23 @@ async def ui_get_available_role(
|
|||
return _data_to_return
|
||||
|
||||
|
||||
def get_team_from_list(
|
||||
team_list: Optional[List[LiteLLM_TeamTable]], team_id: str
|
||||
) -> Optional[LiteLLM_TeamTable]:
|
||||
if team_list is None:
|
||||
return None
|
||||
|
||||
for team in team_list:
|
||||
if team.team_id == team_id:
|
||||
return team
|
||||
return None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/user/info",
|
||||
tags=["Internal User management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=UserInfoResponse,
|
||||
)
|
||||
@management_endpoint_wrapper
|
||||
async def user_info(
|
||||
|
@ -337,7 +351,7 @@ async def user_info(
|
|||
## GET ALL TEAMS ##
|
||||
team_list = []
|
||||
team_id_list = []
|
||||
# _DEPRECATED_ check if user in 'member' field
|
||||
# get all teams user belongs to
|
||||
teams_1 = await prisma_client.get_data(
|
||||
user_id=user_id, table_name="team", query_type="find_all"
|
||||
)
|
||||
|
@ -414,13 +428,13 @@ async def user_info(
|
|||
if (
|
||||
key.token == litellm_master_key_hash
|
||||
and general_settings.get("disable_master_key_return", False)
|
||||
== True ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui
|
||||
is True ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui
|
||||
):
|
||||
continue
|
||||
|
||||
try:
|
||||
key = key.model_dump() # noqa
|
||||
except:
|
||||
except Exception:
|
||||
# if using pydantic v1
|
||||
key = key.dict()
|
||||
if (
|
||||
|
@ -428,29 +442,29 @@ async def user_info(
|
|||
and key["team_id"] is not None
|
||||
and key["team_id"] != "litellm-dashboard"
|
||||
):
|
||||
team_info = await prisma_client.get_data(
|
||||
team_id=key["team_id"], table_name="team"
|
||||
team_info = get_team_from_list(
|
||||
team_list=teams_1, team_id=key["team_id"]
|
||||
)
|
||||
team_alias = getattr(team_info, "team_alias", None)
|
||||
key["team_alias"] = team_alias
|
||||
if team_info is not None:
|
||||
team_alias = getattr(team_info, "team_alias", None)
|
||||
key["team_alias"] = team_alias
|
||||
else:
|
||||
key["team_alias"] = None
|
||||
else:
|
||||
key["team_alias"] = "None"
|
||||
returned_keys.append(key)
|
||||
|
||||
response_data = {
|
||||
"user_id": user_id,
|
||||
"user_info": user_info,
|
||||
"keys": returned_keys,
|
||||
"teams": team_list,
|
||||
}
|
||||
response_data = UserInfoResponse(
|
||||
user_id=user_id, user_info=user_info, keys=returned_keys, teams=team_list
|
||||
)
|
||||
|
||||
return response_data
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.proxy_server.user_info(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
|
|
|
@ -248,7 +248,7 @@ from litellm.secret_managers.aws_secret_manager import (
|
|||
load_aws_secret_manager,
|
||||
)
|
||||
from litellm.secret_managers.google_kms import load_google_kms
|
||||
from litellm.secret_managers.main import get_secret, str_to_bool
|
||||
from litellm.secret_managers.main import get_secret, get_secret_str, str_to_bool
|
||||
from litellm.types.llms.anthropic import (
|
||||
AnthropicMessagesRequest,
|
||||
AnthropicResponse,
|
||||
|
@ -2728,9 +2728,21 @@ async def startup_event():
|
|||
|
||||
### LOAD CONFIG ###
|
||||
worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore
|
||||
env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")
|
||||
verbose_proxy_logger.debug("worker_config: %s", worker_config)
|
||||
# check if it's a valid file path
|
||||
if worker_config is not None:
|
||||
if env_config_yaml is not None:
|
||||
if os.path.isfile(env_config_yaml) and proxy_config.is_yaml(
|
||||
config_file_path=env_config_yaml
|
||||
):
|
||||
(
|
||||
llm_router,
|
||||
llm_model_list,
|
||||
general_settings,
|
||||
) = await proxy_config.load_config(
|
||||
router=llm_router, config_file_path=env_config_yaml
|
||||
)
|
||||
elif worker_config is not None:
|
||||
if (
|
||||
isinstance(worker_config, str)
|
||||
and os.path.isfile(worker_config)
|
||||
|
|
|
@ -31,6 +31,7 @@ from litellm import (
|
|||
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
_gemini_convert_messages_with_history,
|
||||
)
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase
|
||||
from litellm.tests.test_streaming import streaming_format_tests
|
||||
|
||||
litellm.num_retries = 3
|
||||
|
@ -2869,3 +2870,24 @@ def test_gemini_finetuned_endpoint(base_model, metadata):
|
|||
assert mock_client.call_args.kwargs["url"].endswith(
|
||||
"endpoints/4965075652664360960:generateContent"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("api_base", ["", None, "my-custom-proxy-base"])
|
||||
def test_custom_api_base(api_base):
|
||||
stream = None
|
||||
test_endpoint = "my-fake-endpoint"
|
||||
vertex_base = VertexBase()
|
||||
auth_header, url = vertex_base._check_custom_proxy(
|
||||
api_base=api_base,
|
||||
custom_llm_provider="gemini",
|
||||
gemini_api_key="12324",
|
||||
endpoint="",
|
||||
stream=stream,
|
||||
auth_header=None,
|
||||
url="my-fake-endpoint",
|
||||
)
|
||||
|
||||
if api_base:
|
||||
assert url == api_base + ":"
|
||||
else:
|
||||
assert url == test_endpoint
|
||||
|
|
|
@ -1933,6 +1933,41 @@ async def test_openai_compatible_custom_api_base(provider):
|
|||
assert "hello" in mock_call.call_args.kwargs["extra_body"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_litellm_gateway_from_sdk():
|
||||
litellm.set_verbose = True
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello world",
|
||||
}
|
||||
]
|
||||
from openai import OpenAI
|
||||
|
||||
openai_client = OpenAI(api_key="fake-key")
|
||||
|
||||
with patch.object(
|
||||
openai_client.chat.completions, "create", new=MagicMock()
|
||||
) as mock_call:
|
||||
try:
|
||||
completion(
|
||||
model="litellm_proxy/my-vllm-model",
|
||||
messages=messages,
|
||||
response_format={"type": "json_object"},
|
||||
client=openai_client,
|
||||
api_base="my-custom-api-base",
|
||||
hello="world",
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
mock_call.assert_called_once()
|
||||
|
||||
print("Call KWARGS - {}".format(mock_call.call_args.kwargs))
|
||||
|
||||
assert "hello" in mock_call.call_args.kwargs["extra_body"]
|
||||
|
||||
|
||||
# ################### Hugging Face Conversational models ########################
|
||||
# def hf_test_completion_conv():
|
||||
# try:
|
||||
|
|
|
@ -256,7 +256,7 @@ def test_generate_and_call_with_valid_key(prisma_client, api_route):
|
|||
|
||||
# check /user/info to verify user_role was set correctly
|
||||
new_user_info = await user_info(user_id=user_id)
|
||||
new_user_info = new_user_info["user_info"]
|
||||
new_user_info = new_user_info.user_info
|
||||
print("new_user_info=", new_user_info)
|
||||
assert new_user_info.user_role == LitellmUserRoles.INTERNAL_USER
|
||||
assert new_user_info.user_id == user_id
|
||||
|
|
|
@ -2154,7 +2154,7 @@ def test_openai_chat_completion_complete_response_call():
|
|||
# test_openai_chat_completion_complete_response_call()
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
["gpt-3.5-turbo", "azure/chatgpt-v-2", "claude-3-haiku-20240307"], #
|
||||
["gpt-3.5-turbo", "azure/chatgpt-v-2", "claude-3-haiku-20240307", "o1-preview"], #
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"sync",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue