Mirror of https://github.com/BerriAI/litellm.git
Support litellm.api_base for vertex_ai + gemini/ across completion, embedding, image_generation (#9516)
* test(tests): add unit testing for litellm_proxy integration
* fix(cost_calculator.py): fix tracking cost in sdk when calling proxy
* fix(main.py): respect litellm.api_base on `vertex_ai/` and `gemini/` routes
* fix(main.py): consistently support custom api base across gemini + vertexai on embedding + completion
* feat(vertex_ai/): test
* fix: fix linting error
* test: set api base as None before starting loadtest
Parent: 8657816477
Commit: 6fd18651d1

10 changed files with 223 additions and 43 deletions
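In effect, a globally configured `litellm.api_base` (or the `GEMINI_API_BASE` / `VERTEXAI_API_BASE` / `VERTEX_API_BASE` secrets) now overrides the default Google endpoints on the `vertex_ai/` and `gemini/` routes. A minimal usage sketch, mirroring the new test at the bottom of this diff (the base URL is a placeholder):

    import litellm

    # All vertex_ai/ and gemini/ requests are now sent to this base URL
    # (e.g. a LiteLLM proxy) instead of the default Google endpoint.
    litellm.api_base = "https://litellm.example.com"  # placeholder

    response = litellm.completion(
        model="vertex_ai/gemini-2.0-flash-001",
        messages=[{"role": "user", "content": "Hello, world!"}],
    )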
@@ -828,11 +828,14 @@ def get_response_cost_from_hidden_params(
     _hidden_params_dict = hidden_params
 
     additional_headers = _hidden_params_dict.get("additional_headers", {})
-    if additional_headers and "x-litellm-response-cost" in additional_headers:
-        response_cost = additional_headers["x-litellm-response-cost"]
+    if (
+        additional_headers
+        and "llm_provider-x-litellm-response-cost" in additional_headers
+    ):
+        response_cost = additional_headers["llm_provider-x-litellm-response-cost"]
         if response_cost is None:
             return None
-        return float(additional_headers["x-litellm-response-cost"])
+        return float(additional_headers["llm_provider-x-litellm-response-cost"])
     return None
 
 
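The header rename matters because provider response headers appear to be forwarded under a `llm_provider-` prefix in hidden params, so reading the unprefixed name silently dropped the proxy's cost. An illustrative caller-side sketch (the value mirrors the new proxy test below):

    # After an SDK call that was routed through a LiteLLM proxy:
    headers = response._hidden_params.get("additional_headers", {})
    cost = headers.get("llm_provider-x-litellm-response-cost")  # e.g. "120"
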
@@ -55,7 +55,9 @@ def get_supports_response_schema(
 
 from typing import Literal, Optional
 
-all_gemini_url_modes = Literal["chat", "embedding", "batch_embedding"]
+all_gemini_url_modes = Literal[
+    "chat", "embedding", "batch_embedding", "image_generation"
+]
 
 
 def _get_vertex_url(
@@ -91,7 +93,11 @@ def _get_vertex_url(
         if model.isdigit():
             # https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/endpoints/$ENDPOINT_ID:predict
             url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
+    elif mode == "image_generation":
+        endpoint = "predict"
+        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
+        if model.isdigit():
+            url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
 
     if not url or not endpoint:
         raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}")
     return url, endpoint
 
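For reference, here is the same f-string as the new `image_generation` branch, evaluated with placeholder values:

    vertex_location = "us-central1"
    vertex_project = "my-project"  # placeholder project id
    model, endpoint = "imagegeneration", "predict"
    url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
    # -> https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/imagegeneration:predict
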
@@ -127,6 +133,10 @@ def _get_gemini_url(
         url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
             _gemini_model_name, endpoint, gemini_api_key
         )
+    elif mode == "image_generation":
+        raise ValueError(
+            "LiteLLM's `gemini/` route does not support image generation yet. Let us know if you need this feature by opening an issue at https://github.com/BerriAI/litellm/issues"
+        )
 
     return url, endpoint
 
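On the `gemini/` route the new branch fails fast instead of building an invalid URL. What a caller would see (hedged sketch; the model name is a placeholder):

    import litellm

    try:
        litellm.image_generation(model="gemini/imagen-3", prompt="a red bicycle")
    except ValueError as err:
        # "LiteLLM's `gemini/` route does not support image generation yet. ..."
        print(err)
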
@@ -43,22 +43,23 @@ class VertexImageGeneration(VertexLLM):
     def image_generation(
         self,
         prompt: str,
+        api_base: Optional[str],
         vertex_project: Optional[str],
         vertex_location: Optional[str],
         vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
         model_response: ImageResponse,
         logging_obj: Any,
-        model: Optional[
-            str
-        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
+        model: str = "imagegeneration",  # vertex ai uses imagegeneration as the default model
         client: Optional[Any] = None,
         optional_params: Optional[dict] = None,
         timeout: Optional[int] = None,
         aimg_generation=False,
+        extra_headers: Optional[dict] = None,
     ) -> ImageResponse:
         if aimg_generation is True:
             return self.aimage_generation(  # type: ignore
                 prompt=prompt,
+                api_base=api_base,
                 vertex_project=vertex_project,
                 vertex_location=vertex_location,
                 vertex_credentials=vertex_credentials,
 
@@ -83,13 +84,27 @@ class VertexImageGeneration(VertexLLM):
         else:
             sync_handler = client  # type: ignore
 
-        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
+        # url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
 
+        auth_header: Optional[str] = None
         auth_header, _ = self._ensure_access_token(
             credentials=vertex_credentials,
             project_id=vertex_project,
             custom_llm_provider="vertex_ai",
         )
+        auth_header, api_base = self._get_token_and_url(
+            model=model,
+            gemini_api_key=None,
+            auth_header=auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=False,
+            custom_llm_provider="vertex_ai",
+            api_base=api_base,
+            should_use_v1beta1_features=False,
+            mode="image_generation",
+        )
         optional_params = optional_params or {
             "sampleCount": 1
         }  # default optional params
 
@@ -99,31 +114,21 @@ class VertexImageGeneration(VertexLLM):
             "parameters": optional_params,
         }
 
-        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
-        logging_obj.pre_call(
-            input=prompt,
-            api_key=None,
-            additional_args={
-                "complete_input_dict": optional_params,
-                "request_str": request_str,
-            },
-        )
+        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
 
         logging_obj.pre_call(
             input=prompt,
-            api_key=None,
+            api_key="",
             additional_args={
                 "complete_input_dict": optional_params,
-                "request_str": request_str,
+                "api_base": api_base,
+                "headers": headers,
             },
         )
 
         response = sync_handler.post(
-            url=url,
-            headers={
-                "Content-Type": "application/json; charset=utf-8",
-                "Authorization": f"Bearer {auth_header}",
-            },
+            url=api_base,
+            headers=headers,
             data=json.dumps(request_data),
         )
 
@@ -138,17 +143,17 @@ class VertexImageGeneration(VertexLLM):
     async def aimage_generation(
         self,
         prompt: str,
+        api_base: Optional[str],
         vertex_project: Optional[str],
         vertex_location: Optional[str],
         vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
         model_response: litellm.ImageResponse,
         logging_obj: Any,
-        model: Optional[
-            str
-        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
+        model: str = "imagegeneration",  # vertex ai uses imagegeneration as the default model
         client: Optional[AsyncHTTPHandler] = None,
         optional_params: Optional[dict] = None,
         timeout: Optional[int] = None,
+        extra_headers: Optional[dict] = None,
     ):
         response = None
         if client is None:
 
@@ -169,7 +174,6 @@ class VertexImageGeneration(VertexLLM):
 
         # make POST request to
         # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
-        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
 
         """
         Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
 
@@ -188,11 +192,25 @@ class VertexImageGeneration(VertexLLM):
         } \
         "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
         """
+        auth_header: Optional[str] = None
         auth_header, _ = self._ensure_access_token(
             credentials=vertex_credentials,
             project_id=vertex_project,
             custom_llm_provider="vertex_ai",
         )
+        auth_header, api_base = self._get_token_and_url(
+            model=model,
+            gemini_api_key=None,
+            auth_header=auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=False,
+            custom_llm_provider="vertex_ai",
+            api_base=api_base,
+            should_use_v1beta1_features=False,
+            mode="image_generation",
+        )
         optional_params = optional_params or {
             "sampleCount": 1
         }  # default optional params
 
@@ -202,22 +220,21 @@ class VertexImageGeneration(VertexLLM):
             "parameters": optional_params,
         }
 
-        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
+        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
 
         logging_obj.pre_call(
             input=prompt,
-            api_key=None,
+            api_key="",
             additional_args={
                 "complete_input_dict": optional_params,
-                "request_str": request_str,
+                "api_base": api_base,
+                "headers": headers,
             },
         )
 
         response = await self.async_handler.post(
-            url=url,
-            headers={
-                "Content-Type": "application/json; charset=utf-8",
-                "Authorization": f"Bearer {auth_header}",
-            },
+            url=api_base,
+            headers=headers,
             data=json.dumps(request_data),
         )
 
@@ -111,7 +111,7 @@ class VertexEmbedding(VertexBase):
         )
 
         try:
-            response = client.post(api_base, headers=headers, json=vertex_request)  # type: ignore
+            response = client.post(url=api_base, headers=headers, json=vertex_request)  # type: ignore
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code
 
@@ -2350,6 +2350,8 @@ def completion(  # type: ignore # noqa: PLR0915
             or litellm.api_key
         )
 
+        api_base = api_base or litellm.api_base or get_secret("GEMINI_API_BASE")
+
         new_params = deepcopy(optional_params)
         response = vertex_chat_completion.completion(  # type: ignore
             model=model,
 
@@ -2392,6 +2394,8 @@ def completion(  # type: ignore # noqa: PLR0915
             or get_secret("VERTEXAI_CREDENTIALS")
         )
 
+        api_base = api_base or litellm.api_base or get_secret("VERTEXAI_API_BASE")
+
         new_params = deepcopy(optional_params)
         if (
             model.startswith("meta/")
 
@@ -3657,6 +3661,8 @@ def embedding(  # noqa: PLR0915
             api_key or get_secret_str("GEMINI_API_KEY") or litellm.api_key
         )
 
+        api_base = api_base or litellm.api_base or get_secret_str("GEMINI_API_BASE")
+
         response = google_batch_embeddings.batch_embeddings(  # type: ignore
             model=model,
             input=input,
 
@@ -3671,6 +3677,8 @@ def embedding(  # noqa: PLR0915
             print_verbose=print_verbose,
             custom_llm_provider="gemini",
             api_key=gemini_api_key,
+            api_base=api_base,
+            client=client,
         )
 
     elif custom_llm_provider == "vertex_ai":
 
@@ -3695,6 +3703,13 @@ def embedding(  # noqa: PLR0915
             or get_secret_str("VERTEX_CREDENTIALS")
         )
 
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret_str("VERTEXAI_API_BASE")
+            or get_secret_str("VERTEX_API_BASE")
+        )
+
         if (
             "image" in optional_params
             or "video" in optional_params
 
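The fallback order introduced here is: the explicit `api_base` argument, then `litellm.api_base`, then the `VERTEXAI_API_BASE` / `VERTEX_API_BASE` secrets. A hedged sketch of the env-var fallback (proxy URL and model name are placeholders):

    import os

    import litellm

    os.environ["VERTEXAI_API_BASE"] = "https://my-vertex-proxy.internal"  # placeholder
    litellm.embedding(model="vertex_ai/text-embedding-005", input=["hello"])
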
@@ -3716,6 +3731,7 @@ def embedding(  # noqa: PLR0915
                 print_verbose=print_verbose,
                 custom_llm_provider="vertex_ai",
                 client=client,
+                api_base=api_base,
             )
         else:
             response = vertex_embedding.embedding(
 
@@ -3733,6 +3749,8 @@ def embedding(  # noqa: PLR0915
                 aembedding=aembedding,
                 print_verbose=print_verbose,
                 api_key=api_key,
+                api_base=api_base,
+                client=client,
             )
     elif custom_llm_provider == "oobabooga":
         response = oobabooga.embedding(
 
@@ -4695,6 +4713,14 @@ def image_generation(  # noqa: PLR0915
             or optional_params.pop("vertex_ai_credentials", None)
             or get_secret_str("VERTEXAI_CREDENTIALS")
         )
+
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret_str("VERTEXAI_API_BASE")
+            or get_secret_str("VERTEX_API_BASE")
+        )
+
         model_response = vertex_image_generation.image_generation(
             model=model,
             prompt=prompt,
 
@@ -4706,6 +4732,8 @@ def image_generation(  # noqa: PLR0915
             vertex_location=vertex_ai_location,
             vertex_credentials=vertex_credentials,
             aimg_generation=aimg_generation,
+            api_base=api_base,
+            client=client,
         )
     elif (
         custom_llm_provider in litellm._custom_providers
 
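With this wiring, image generation honors the same override as completion and embedding. A hedged sketch (base URL and model name are placeholders):

    import litellm

    litellm.api_base = "https://litellm.example.com"  # placeholder
    litellm.image_generation(
        model="vertex_ai/imagegeneration",
        prompt="a watercolor of a lighthouse",
    )
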
@@ -1,5 +1,5 @@
 model_list:
-  - model_name: "gpt-3.5-turbo"
+  - model_name: "gpt-4o"
     litellm_params:
       model: azure/chatgpt-v-2
       api_key: os.environ/AZURE_API_KEY
 
@@ -15,9 +15,11 @@ from pydantic import BaseModel
 from litellm.cost_calculator import response_cost_calculator
 
 
-def test_cost_calculator():
+def test_cost_calculator_with_response_cost_in_additional_headers():
     class MockResponse(BaseModel):
-        _hidden_params = {"additional_headers": {"x-litellm-response-cost": 1000}}
+        _hidden_params = {
+            "additional_headers": {"llm_provider-x-litellm-response-cost": 1000}
+        }
 
     result = response_cost_calculator(
         response_object=MockResponse(),
 
@@ -31,7 +31,7 @@ async def test_litellm_gateway_from_sdk():
     openai_client = OpenAI(api_key="fake-key")
 
     with patch.object(
-        openai_client.chat.completions, "create", new=MagicMock()
+        openai_client.chat.completions.with_raw_response, "create", new=MagicMock()
     ) as mock_call:
         try:
             completion(
 
@@ -374,3 +374,78 @@ async def test_litellm_gateway_from_sdk_rerank(is_async):
     assert request_body["query"] == "What is machine learning?"
     assert request_body["model"] == "rerank-english-v2.0"
     assert len(request_body["documents"]) == 2
+
+
+def test_litellm_gateway_from_sdk_with_response_cost_in_additional_headers():
+    litellm.set_verbose = True
+    litellm._turn_on_debug()
+
+    from openai import OpenAI
+
+    openai_client = OpenAI(api_key="fake-key")
+
+    # Create mock response object
+    mock_response = MagicMock()
+    mock_response.headers = {"x-litellm-response-cost": "120"}
+    mock_response.parse.return_value = litellm.ModelResponse(
+        **{
+            "id": "chatcmpl-BEkxQvRGp9VAushfAsOZCbhMFLsoy",
+            "choices": [
+                {
+                    "finish_reason": "stop",
+                    "index": 0,
+                    "logprobs": None,
+                    "message": {
+                        "content": "Hello! How can I assist you today?",
+                        "refusal": None,
+                        "role": "assistant",
+                        "annotations": [],
+                        "audio": None,
+                        "function_call": None,
+                        "tool_calls": None,
+                    },
+                }
+            ],
+            "created": 1742856796,
+            "model": "gpt-4o-2024-08-06",
+            "object": "chat.completion",
+            "service_tier": "default",
+            "system_fingerprint": "fp_6ec83003ad",
+            "usage": {
+                "completion_tokens": 10,
+                "prompt_tokens": 9,
+                "total_tokens": 19,
+                "completion_tokens_details": {
+                    "accepted_prediction_tokens": 0,
+                    "audio_tokens": 0,
+                    "reasoning_tokens": 0,
+                    "rejected_prediction_tokens": 0,
+                },
+                "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
+            },
+        }
+    )
+
+    with patch.object(
+        openai_client.chat.completions.with_raw_response,
+        "create",
+        return_value=mock_response,
+    ) as mock_call:
+        response = litellm.completion(
+            model="litellm_proxy/gpt-4o",
+            messages=[{"role": "user", "content": "Hello world"}],
+            api_base="http://0.0.0.0:4000",
+            api_key="sk-PIp1h0RekR",
+            client=openai_client,
+        )
+
+    # Assert the headers were properly passed through
+    print(f"additional_headers: {response._hidden_params['additional_headers']}")
+    assert (
+        response._hidden_params["additional_headers"][
+            "llm_provider-x-litellm-response-cost"
+        ]
+        == "120"
+    )
+
+    assert response._hidden_params["response_cost"] == 120
 
@@ -109,12 +109,13 @@ def analyze_results(vertex_times):
 
 
 @pytest.mark.asyncio
-async def test_embedding_performance():
+async def test_embedding_performance(monkeypatch):
     """
     Run load test on vertex AI embeddings to ensure vertex median response time is less than 300ms
 
     20 RPS for 20 seconds
     """
+    monkeypatch.setattr(litellm, "api_base", None)
     duration_seconds = 20
     requests_per_second = 20
     vertex_times = await run_load_test(duration_seconds, requests_per_second)
 
@@ -31,6 +31,7 @@ from litellm import (
     completion,
     completion_cost,
     embedding,
+    image_generation,
 )
 from litellm.llms.vertex_ai.gemini.transformation import (
     _gemini_convert_messages_with_history,
 
@@ -3327,3 +3328,46 @@ def test_signed_s3_url_with_format():
     json_str = json.dumps(mock_client.call_args.kwargs["json"])
     assert "image/jpeg" in json_str
     assert "image/png" not in json_str
+
+
+@pytest.mark.parametrize("provider", ["vertex_ai", "gemini"])
+@pytest.mark.parametrize("route", ["completion", "embedding", "image_generation"])
+def test_litellm_api_base(monkeypatch, provider, route):
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+
+    import litellm
+
+    monkeypatch.setattr(litellm, "api_base", "https://litellm.com")
+
+    load_vertex_ai_credentials()
+
+    if route == "image_generation" and provider == "gemini":
+        pytest.skip("Gemini does not support image generation")
+
+    with patch.object(client, "post", new=MagicMock()) as mock_client:
+        try:
+            if route == "completion":
+                response = completion(
+                    model=f"{provider}/gemini-2.0-flash-001",
+                    messages=[{"role": "user", "content": "Hello, world!"}],
+                    client=client,
+                )
+            elif route == "embedding":
+                response = embedding(
+                    model=f"{provider}/gemini-2.0-flash-001",
+                    input=["Hello, world!"],
+                    client=client,
+                )
+            elif route == "image_generation":
+                response = image_generation(
+                    model=f"{provider}/gemini-2.0-flash-001",
+                    prompt="Hello, world!",
+                    client=client,
+                )
+        except Exception as e:
+            print(e)
+
+        mock_client.assert_called()
+        assert mock_client.call_args.kwargs["url"].startswith("https://litellm.com")