Merge pull request #5119 from BerriAI/litellm_add_gemini_context_caching_litellm
[Feat-Proxy] Add Support for VertexAI context caching
commit e671ae58e3
7 changed files with 279 additions and 36 deletions
@@ -427,6 +427,105 @@ print(resp)
```

### **Context Caching**

Use Vertex AI Context Caching.

[**Relevant VertexAI Docs**](https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-overview)
<Tabs>
<TabItem value="proxy" label="LiteLLM PROXY">
1. Add model to config.yaml

```yaml
model_list:
  # used for /chat/completions, /completions, /embeddings endpoints
  - model_name: gemini-1.5-pro-001
    litellm_params:
      model: vertex_ai_beta/gemini-1.5-pro-001
      vertex_project: "project-id"
      vertex_location: "us-central1"
      vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json

# used for the /cachedContents and vertexAI native endpoints
default_vertex_config:
  vertex_project: "adroit-crow-413218"
  vertex_location: "us-central1"
  vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json
```
2. Start Proxy

```
$ litellm --config /path/to/config.yaml
```
3. Make Request!

- First, create a cachedContents object by calling the Vertex `cachedContents` endpoint. [VertexAI API Ref for the cachedContents endpoint](https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-create#create-context-cache-sample-drest). (The LiteLLM proxy forwards the `/cachedContents` request to the VertexAI API.)
- Then use the returned `cachedContents` object in your `/chat/completions` request to VertexAI.
```python
import datetime
import openai
import httpx

# Set LiteLLM proxy variables here
LITELLM_BASE_URL = "http://0.0.0.0:4000"
LITELLM_PROXY_API_KEY = "sk-1234"

client = openai.OpenAI(api_key=LITELLM_PROXY_API_KEY, base_url=LITELLM_BASE_URL)
httpx_client = httpx.Client(timeout=30)

################################
# First create a cachedContents object
# this request gets forwarded as-is to: https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-create#create-context-cache-sample-drest
print("creating cached content")
create_cache = httpx_client.post(
    url=f"{LITELLM_BASE_URL}/vertex-ai/cachedContents",
    headers={"Authorization": f"Bearer {LITELLM_PROXY_API_KEY}"},
    json={
        "model": "gemini-1.5-pro-001",
        "contents": [
            {
                "role": "user",
                "parts": [{
                    "text": "This is sample text to demonstrate explicit caching." * 4000
                }]
            }
        ],
    },
)
print("response from create_cache", create_cache)
create_cache_response = create_cache.json()
print("json from create_cache", create_cache_response)
cached_content_name = create_cache_response["name"]

#################################
# Use the `cachedContents` object in your /chat/completions request
response = client.chat.completions.create(  # type: ignore
    model="gemini-1.5-pro-001",
    max_tokens=8192,
    messages=[
        {
            "role": "user",
            "content": "what is the sample text about?",
        },
    ],
    temperature=0.7,
    extra_body={"cached_content": cached_content_name},  # 👈 key change
)

print("response from proxy", response)
```
</TabItem>
</Tabs>
## Pre-requisites
* `pip install google-cloud-aiplatform` (pre-installed on proxy docker image)
* Authentication:
@@ -278,6 +278,14 @@ class VertexFineTuningAPI(VertexLLM):
            url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
        elif "countTokens" in request_route:
            url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
        elif "cachedContents" in request_route:
            _model = request_data.get("model")
            if _model is not None and "/publishers/google/models/" not in _model:
                request_data["model"] = (
                    f"projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{_model}"
                )

            url = f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
        else:
            raise ValueError(f"Unsupported Vertex AI request route: {request_route}")
        if self.async_handler is None:
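The `cachedContents` branch above rewrites short model names into the fully qualified Vertex resource path before forwarding the request, and routes it to `v1beta1`. A minimal sketch of that normalization, using the same project/location values as the diff (the helper name `normalize_model` is illustrative, not part of the PR):

```python
def normalize_model(model: str, vertex_project: str, vertex_location: str) -> str:
    # Mirrors the cachedContents branch above: short names like
    # "gemini-1.5-pro-001" become the full publisher path, while
    # already-qualified names pass through unchanged.
    if "/publishers/google/models/" in model:
        return model
    return (
        f"projects/{vertex_project}/locations/{vertex_location}"
        f"/publishers/google/models/{model}"
    )


# Example (assumed values):
# normalize_model("gemini-1.5-pro-001", "adroit-crow-413218", "us-central1")
# -> "projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.5-pro-001"
```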
@@ -881,6 +881,21 @@ class VertexLLM(BaseLLM):

        return self._credentials.token, self.project_id

    def is_using_v1beta1_features(self, optional_params: dict) -> bool:
        """
        VertexAI only supports context caching on v1beta1.

        Use this helper to decide if a request should be sent to v1 or v1beta1.

        Returns True (use v1beta1) if context caching is enabled.
        Returns False (use v1) in all other cases.
        """
        if "cached_content" in optional_params:
            return True
        if "CachedContent" in optional_params:
            return True
        return False

    def _get_token_and_url(
        self,
        model: str,
@@ -891,6 +906,7 @@ class VertexLLM(BaseLLM):
        stream: Optional[bool],
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        api_base: Optional[str],
        should_use_v1beta1_features: Optional[bool] = False,
    ) -> Tuple[Optional[str], str]:
        """
        Internal function. Returns the token and url for the call.
@@ -920,12 +936,13 @@
            vertex_location = self.get_vertex_region(vertex_region=vertex_location)

            ### SET RUNTIME ENDPOINT ###
            version = "v1beta1" if should_use_v1beta1_features is True else "v1"
            endpoint = "generateContent"
            if stream is True:
                endpoint = "streamGenerateContent"
                url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
                url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
            else:
                url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
                url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"

            if (
                api_base is not None
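The net effect of `is_using_v1beta1_features` plus the `{version}` URL template above: any request carrying a `cached_content`/`CachedContent` parameter is sent to the `v1beta1` API, everything else stays on `v1`. A minimal sketch of that decision (a standalone snippet, not code from this PR; it matches the behavior exercised by `test_get_token_url` below):

```python
from litellm.llms.vertex_httpx import VertexLLM

vertex_llm = VertexLLM()

# context caching present -> route to v1beta1
assert vertex_llm.is_using_v1beta1_features({"cached_content": "projects/.../cachedContents/123"}) is True

# no caching params -> stay on v1
assert vertex_llm.is_using_v1beta1_features({"temperature": 0.1}) is False

# The flag selects the version segment of the runtime endpoint, e.g. (illustrative values):
# https://us-central1-aiplatform.googleapis.com/v1beta1/projects/my-project/locations/us-central1/publishers/google/models/gemini-1.5-pro-001:generateContent
```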
@@ -1055,6 +1072,9 @@ class VertexLLM(BaseLLM):
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        stream: Optional[bool] = optional_params.pop("stream", None)  # type: ignore

        should_use_v1beta1_features = self.is_using_v1beta1_features(
            optional_params=optional_params
        )
        auth_header, url = self._get_token_and_url(
            model=model,
            gemini_api_key=gemini_api_key,

@@ -1064,6 +1084,7 @@ class VertexLLM(BaseLLM):
            stream=stream,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=should_use_v1beta1_features,
        )

        ## TRANSFORMATION ##
@@ -3,20 +3,14 @@ model_list:
    litellm_params:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railwaz.app/
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
  - model_name: fireworks-llama-v3-70b-instruct
    litellm_params:
      model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
      api_key: "os.environ/FIREWORKS"
  # provider specific wildcard routing
  - model_name: "anthropic/*"
  - model_name: "*"
    litellm_params:
      model: "anthropic/*"
      api_key: os.environ/ANTHROPIC_API_KEY
  - model_name: "groq/*"
    litellm_params:
      model: "groq/*"
      api_key: os.environ/GROQ_API_KEY
      model: "*"
  - model_name: "*"
    litellm_params:
      model: openai/*
@@ -25,37 +19,22 @@ model_list:
    litellm_params:
      model: mistral/mistral-small-latest
      api_key: "os.environ/MISTRAL_API_KEY"
  - model_name: tts
  - model_name: gemini-1.5-pro-001
    litellm_params:
      model: openai/tts-1
      api_key: "os.environ/OPENAI_API_KEY"
    model_info:
      mode: audio_speech


# for /files endpoints
files_settings:
  - custom_llm_provider: azure
    api_base: https://exampleopenaiendpoint-production.up.railway.app
    api_key: fake-key
    api_version: "2023-03-15-preview"
  - custom_llm_provider: openai
    api_key: os.environ/OPENAI_API_KEY
      model: vertex_ai_beta/gemini-1.5-pro-001
      vertex_project: "adroit-crow-413218"
      vertex_location: "us-central1"
      vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json

default_vertex_config:
  vertex_project: "adroit-crow-413218"
  vertex_location: "us-central1"
  vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json


general_settings:
  master_key: sk-1234
  pass_through_endpoints:
    - path: "/v1/rerank" # route you want to add to LiteLLM Proxy Server
      target: "https://api.cohere.com/v1/rerank" # URL this route should forward requests to
      headers: # headers to forward to this URL
        content-type: application/json # (Optional) Extra Headers to pass to this endpoint
        accept: application/json
      forward_headers: True


litellm_settings:
  callbacks: ["otel"] # 👈 KEY CHANGE
  success_callback: ["prometheus"]
  failure_callback: ["prometheus"]
litellm/proxy/tests/test_gemini_context_caching.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import datetime

import httpx
import openai

# Set LiteLLM proxy variables here
LITELLM_BASE_URL = "http://0.0.0.0:4000"
LITELLM_PROXY_API_KEY = "sk-1234"

client = openai.OpenAI(api_key=LITELLM_PROXY_API_KEY, base_url=LITELLM_BASE_URL)
httpx_client = httpx.Client(timeout=30)

################################
# First create a cachedContents object
print("creating cached content")
create_cache = httpx_client.post(
    url=f"{LITELLM_BASE_URL}/vertex-ai/cachedContents",
    headers={"Authorization": f"Bearer {LITELLM_PROXY_API_KEY}"},
    json={
        "model": "gemini-1.5-pro-001",
        "contents": [
            {
                "role": "user",
                "parts": [
                    {
                        "text": "This is sample text to demonstrate explicit caching."
                        * 4000
                    }
                ],
            }
        ],
    },
)
print("response from create_cache", create_cache)
create_cache_response = create_cache.json()
print("json from create_cache", create_cache_response)
cached_content_name = create_cache_response["name"]

#################################
# Use the `cachedContents` object in your /chat/completions
response = client.chat.completions.create(  # type: ignore
    model="gemini-1.5-pro-001",
    max_tokens=8192,
    messages=[
        {
            "role": "user",
            "content": "what is the sample text about?",
        },
    ],
    temperature=0.7,
    extra_body={"cached_content": cached_content_name},  # 👈 key change
)

print("response from proxy", response)
@@ -303,3 +303,30 @@ async def vertex_cancel_fine_tuning_job(
        return response
    except Exception as e:
        raise exception_handler(e) from e


@router.post(
    "/vertex-ai/cachedContents",
    dependencies=[Depends(user_api_key_auth)],
    tags=["Vertex AI endpoints"],
)
async def vertex_create_add_cached_content(
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    This is a pass-through endpoint for the Vertex AI `/cachedContents` endpoint.

    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-create#create-context-cache-sample-drest

    It uses the Vertex AI credentials configured on the proxy and forwards the request to the Vertex AI API.
    """
    try:
        response = await execute_post_vertex_ai_request(
            request=request,
            route="/cachedContents",
        )
        return response
    except Exception as e:
        raise exception_handler(e) from e
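For orientation, this is roughly what the proxy forwards upstream when a client posts to `/vertex-ai/cachedContents`, combining the routing and model-name rewriting added in `VertexFineTuningAPI` above (illustrative values only; the actual header and credential handling lives in `execute_post_vertex_ai_request`, which is not shown in this diff):

```python
# Hypothetical illustration of the forwarded request (not code from the PR).
vertex_project = "adroit-crow-413218"
vertex_location = "us-central1"

# /cachedContents is only available on v1beta1
upstream_url = (
    f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1"
    f"/projects/{vertex_project}/locations/{vertex_location}/cachedContents"
)

upstream_body = {
    # short model name rewritten to the full publisher path before forwarding
    "model": (
        f"projects/{vertex_project}/locations/{vertex_location}"
        "/publishers/google/models/gemini-1.5-pro-001"
    ),
    "contents": [
        {"role": "user", "parts": [{"text": "This is sample text to demonstrate explicit caching." * 4000}]}
    ],
}

print(upstream_url)
```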
@@ -1969,3 +1969,58 @@ def test_prompt_factory_nested():
        assert isinstance(
            message["parts"][0]["text"], str
        ), "'text' value not a string."


def test_get_token_url():
    from litellm.llms.vertex_httpx import VertexLLM

    vertex_llm = VertexLLM()
    vertex_ai_project = "adroit-crow-413218"
    vertex_ai_location = "us-central1"
    json_obj = get_vertex_ai_creds_json()
    vertex_credentials = json.dumps(json_obj)

    should_use_v1beta1_features = vertex_llm.is_using_v1beta1_features(
        optional_params={"cached_content": "hi"}
    )

    assert should_use_v1beta1_features is True

    _, url = vertex_llm._get_token_and_url(
        vertex_project=vertex_ai_project,
        vertex_location=vertex_ai_location,
        vertex_credentials=vertex_credentials,
        gemini_api_key="",
        custom_llm_provider="vertex_ai_beta",
        should_use_v1beta1_features=should_use_v1beta1_features,
        api_base=None,
        model="",
        stream=False,
    )

    print("url=", url)

    assert "/v1beta1/" in url

    should_use_v1beta1_features = vertex_llm.is_using_v1beta1_features(
        optional_params={"temperature": 0.1}
    )

    _, url = vertex_llm._get_token_and_url(
        vertex_project=vertex_ai_project,
        vertex_location=vertex_ai_location,
        vertex_credentials=vertex_credentials,
        gemini_api_key="",
        custom_llm_provider="vertex_ai_beta",
        should_use_v1beta1_features=should_use_v1beta1_features,
        api_base=None,
        model="",
        stream=False,
    )

    print("url for normal request", url)

    assert "v1beta1" not in url
    assert "/v1/" in url

    pass