Merge pull request #5119 from BerriAI/litellm_add_gemini_context_caching_litellm

[Feat-Proxy] Add Support for VertexAI context caching
Ishaan Jaff 2024-08-08 16:08:58 -07:00 committed by GitHub
commit e671ae58e3
7 changed files with 279 additions and 36 deletions


@@ -427,6 +427,105 @@ print(resp)
```
### **Context Caching**
Use Vertex AI Context Caching with LiteLLM.
[**Relevant VertexAI Docs**](https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-overview)
<Tabs>
<TabItem value="proxy" label="LiteLLM PROXY">
1. Add model to config.yaml
```yaml
model_list:
# used for /chat/completions, /completions, /embeddings endpoints
- model_name: gemini-1.5-pro-001
litellm_params:
model: vertex_ai_beta/gemini-1.5-pro-001
vertex_project: "project-id"
vertex_location: "us-central1"
vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json
# used for the /cachedContents and Vertex AI native endpoints
default_vertex_config:
vertex_project: "adroit-crow-413218"
vertex_location: "us-central1"
vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json
```
2. Start Proxy
```
$ litellm --config /path/to/config.yaml
```
3. Make Request!
- First, create a `cachedContents` object by calling the Vertex `cachedContents` endpoint. [Vertex AI API Ref for the cachedContents endpoint](https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-create#create-context-cache-sample-drest). (The LiteLLM proxy forwards the `/cachedContents` request to the Vertex AI API.)
- Use the returned `cachedContents` object in your `/chat/completions` request to Vertex AI
```python
import datetime
import openai
import httpx
# Set Litellm proxy variables here
LITELLM_BASE_URL = "http://0.0.0.0:4000"
LITELLM_PROXY_API_KEY = "sk-1234"
client = openai.OpenAI(api_key=LITELLM_PROXY_API_KEY, base_url=LITELLM_BASE_URL)
httpx_client = httpx.Client(timeout=30)
################################
# First create a cachedContents object
# this request gets forwarded as is to: https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-create#create-context-cache-sample-drest
print("creating cached content")
create_cache = httpx_client.post(
url=f"{LITELLM_BASE_URL}/vertex-ai/cachedContents",
headers = {"Authorization": f"Bearer {LITELLM_PROXY_API_KEY}"},
json = {
"model": "gemini-1.5-pro-001",
"contents": [
{
"role": "user",
"parts": [{
"text": "This is sample text to demonstrate explicit caching."*4000
}]
}
],
}
)
print("response from create_cache", create_cache)
create_cache_response = create_cache.json()
print("json from create_cache", create_cache_response)
cached_content_name = create_cache_response["name"]
#################################
# Use the `cachedContents` object in your /chat/completions
response = client.chat.completions.create( # type: ignore
model="gemini-1.5-pro-001",
max_tokens=8192,
messages=[
{
"role": "user",
"content": "what is the sample text about?",
},
],
temperature=0.7,
extra_body={"cached_content": cached_content_name}, # 👈 key change
)
print("response from proxy", response)
```
</TabItem>
</Tabs>
## Pre-requisites
* `pip install google-cloud-aiplatform` (pre-installed on proxy docker image)
* Authentication:


@@ -278,6 +278,14 @@ class VertexFineTuningAPI(VertexLLM):
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
elif "countTokens" in request_route:
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
elif "cachedContents" in request_route:
_model = request_data.get("model")
if _model is not None and "/publishers/google/models/" not in _model:
request_data["model"] = (
f"projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{_model}"
)
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
else:
raise ValueError(f"Unsupported Vertex AI request route: {request_route}")
if self.async_handler is None:
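The new `cachedContents` branch above expands a bare model name into the full Vertex AI publisher-model resource path before the request is forwarded. A minimal, standalone sketch of that rewrite (project and location values are illustrative):

```python
# Sketch of the model-name rewrite added above; values are illustrative only.
vertex_project = "adroit-crow-413218"
vertex_location = "us-central1"

# request body as it arrives at the proxy ("contents" omitted for brevity)
request_data = {"model": "gemini-1.5-pro-001"}

_model = request_data.get("model")
if _model is not None and "/publishers/google/models/" not in _model:
    request_data["model"] = (
        f"projects/{vertex_project}/locations/{vertex_location}"
        f"/publishers/google/models/{_model}"
    )

print(request_data["model"])
# projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.5-pro-001
```

Requests that already carry a full `/publishers/google/models/` path are left untouched.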


@@ -881,6 +881,21 @@ class VertexLLM(BaseLLM):
return self._credentials.token, self.project_id
def is_using_v1beta1_features(self, optional_params: dict) -> bool:
"""
Vertex AI only supports context caching on v1beta1.
Use this helper to decide whether a request should be sent to v1 or v1beta1.
Returns True (use v1beta1) if context caching is enabled.
Returns False (use v1) in all other cases.
"""
if "cached_content" in optional_params:
return True
if "CachedContent" in optional_params:
return True
return False
def _get_token_and_url(
self,
model: str,
@@ -891,6 +906,7 @@
stream: Optional[bool],
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
api_base: Optional[str],
should_use_v1beta1_features: Optional[bool] = False,
) -> Tuple[Optional[str], str]:
"""
Internal function. Returns the token and url for the call.
@@ -920,12 +936,13 @@
vertex_location = self.get_vertex_region(vertex_region=vertex_location)
### SET RUNTIME ENDPOINT ###
version = "v1beta1" if should_use_v1beta1_features is True else "v1"
endpoint = "generateContent"
if stream is True:
endpoint = "streamGenerateContent"
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
else:
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
if (
api_base is not None
@@ -1055,6 +1072,9 @@
) -> Union[ModelResponse, CustomStreamWrapper]:
stream: Optional[bool] = optional_params.pop("stream", None) # type: ignore
should_use_v1beta1_features = self.is_using_v1beta1_features(
optional_params=optional_params
)
auth_header, url = self._get_token_and_url(
model=model,
gemini_api_key=gemini_api_key,
@@ -1064,6 +1084,7 @@
stream=stream,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
should_use_v1beta1_features=should_use_v1beta1_features,
)
## TRANSFORMATION ##
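Combined, the changes above only switch a request onto the v1beta1 API surface when context caching is involved. A simplified sketch of the decision and the resulting URL shape (not the actual helper; values are illustrative):

```python
# Simplified sketch of the v1/v1beta1 selection and URL construction shown above.
def build_generate_content_url(vertex_project: str, vertex_location: str, model: str,
                               stream: bool, optional_params: dict) -> str:
    # context caching is only available on v1beta1
    use_v1beta1 = (
        "cached_content" in optional_params or "CachedContent" in optional_params
    )
    version = "v1beta1" if use_v1beta1 else "v1"
    endpoint = "streamGenerateContent" if stream else "generateContent"
    url = (
        f"https://{vertex_location}-aiplatform.googleapis.com/{version}"
        f"/projects/{vertex_project}/locations/{vertex_location}"
        f"/publishers/google/models/{model}:{endpoint}"
    )
    if stream:
        url += "?alt=sse"
    return url


print(build_generate_content_url(
    "adroit-crow-413218", "us-central1", "gemini-1.5-pro-001",
    stream=False,
    optional_params={"cached_content": "projects/adroit-crow-413218/locations/us-central1/cachedContents/1234567890"},
))
# https://us-central1-aiplatform.googleapis.com/v1beta1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.5-pro-001:generateContent
```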


@@ -3,20 +3,14 @@ model_list:
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railwaz.app/
api_base: https://exampleopenaiendpoint-production.up.railway.app/
- model_name: fireworks-llama-v3-70b-instruct
litellm_params:
model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
api_key: "os.environ/FIREWORKS"
# provider specific wildcard routing
- model_name: "anthropic/*"
- model_name: "*"
litellm_params:
model: "anthropic/*"
api_key: os.environ/ANTHROPIC_API_KEY
- model_name: "groq/*"
litellm_params:
model: "groq/*"
api_key: os.environ/GROQ_API_KEY
model: "*"
- model_name: "*"
litellm_params:
model: openai/*
@@ -25,37 +19,22 @@ model_list:
litellm_params:
model: mistral/mistral-small-latest
api_key: "os.environ/MISTRAL_API_KEY"
- model_name: tts
- model_name: gemini-1.5-pro-001
litellm_params:
model: openai/tts-1
api_key: "os.environ/OPENAI_API_KEY"
model_info:
mode: audio_speech
# for /files endpoints
files_settings:
- custom_llm_provider: azure
api_base: https://exampleopenaiendpoint-production.up.railway.app
api_key: fake-key
api_version: "2023-03-15-preview"
- custom_llm_provider: openai
api_key: os.environ/OPENAI_API_KEY
model: vertex_ai_beta/gemini-1.5-pro-001
vertex_project: "adroit-crow-413218"
vertex_location: "us-central1"
vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json
default_vertex_config:
vertex_project: "adroit-crow-413218"
vertex_location: "us-central1"
vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json
general_settings:
master_key: sk-1234
pass_through_endpoints:
- path: "/v1/rerank" # route you want to add to LiteLLM Proxy Server
target: "https://api.cohere.com/v1/rerank" # URL this route should forward requests to
headers: # headers to forward to this URL
content-type: application/json # (Optional) Extra Headers to pass to this endpoint
accept: application/json
forward_headers: True
litellm_settings:
callbacks: ["otel"] # 👈 KEY CHANGE
success_callback: ["prometheus"]
failure_callback: ["prometheus"]


@@ -0,0 +1,54 @@
import datetime
import httpx
import openai
# Set Litellm proxy variables here
LITELLM_BASE_URL = "http://0.0.0.0:4000"
LITELLM_PROXY_API_KEY = "sk-1234"
client = openai.OpenAI(api_key=LITELLM_PROXY_API_KEY, base_url=LITELLM_BASE_URL)
httpx_client = httpx.Client(timeout=30)
################################
# First create a cachedContents object
print("creating cached content")
create_cache = httpx_client.post(
url=f"{LITELLM_BASE_URL}/vertex-ai/cachedContents",
headers={"Authorization": f"Bearer {LITELLM_PROXY_API_KEY}"},
json={
"model": "gemini-1.5-pro-001",
"contents": [
{
"role": "user",
"parts": [
{
"text": "This is sample text to demonstrate explicit caching."
* 4000
}
],
}
],
},
)
print("response from create_cache", create_cache)
create_cache_response = create_cache.json()
print("json from create_cache", create_cache_response)
cached_content_name = create_cache_response["name"]
#################################
# Use the `cachedContents` object in your /chat/completions
response = client.chat.completions.create( # type: ignore
model="gemini-1.5-pro-001",
max_tokens=8192,
messages=[
{
"role": "user",
"content": "what is the sample text about?",
},
],
temperature=0.7,
extra_body={"cached_content": cached_content_name}, # 👈 key change
)
print("response from proxy", response)


@@ -303,3 +303,30 @@ async def vertex_cancel_fine_tuning_job(
return response
except Exception as e:
raise exception_handler(e) from e
@router.post(
"/vertex-ai/cachedContents",
dependencies=[Depends(user_api_key_auth)],
tags=["Vertex AI endpoints"],
)
async def vertex_create_add_cached_content(
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
This is a pass-through endpoint for the Vertex AI `/cachedContents` API.
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-create#create-context-cache-sample-drest
It uses the Vertex AI credentials configured on the proxy and forwards the request to the Vertex AI API.
"""
try:
response = await execute_post_vertex_ai_request(
request=request,
route="/cachedContents",
)
return response
except Exception as e:
raise exception_handler(e) from e
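For reference, a minimal client-side sketch of exercising this new pass-through route, under the same assumptions as the example script above (proxy running locally at `http://0.0.0.0:4000` with master key `sk-1234`):

```python
import httpx

# Create a cachedContents object via the proxy's pass-through route.
# Assumes the proxy from the example config is running locally with master key sk-1234.
resp = httpx.post(
    "http://0.0.0.0:4000/vertex-ai/cachedContents",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "gemini-1.5-pro-001",
        "contents": [
            {
                "role": "user",
                "parts": [
                    {"text": "This is sample text to demonstrate explicit caching." * 4000}
                ],
            }
        ],
    },
    timeout=30,
)
resp.raise_for_status()
print(resp.json()["name"])  # resource name of the created cachedContents object
```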


@@ -1969,3 +1969,58 @@ def test_prompt_factory_nested():
assert isinstance(
message["parts"][0]["text"], str
), "'text' value not a string."
def test_get_token_url():
from litellm.llms.vertex_httpx import VertexLLM
vertex_llm = VertexLLM()
vertex_ai_project = "adroit-crow-413218"
vertex_ai_location = "us-central1"
json_obj = get_vertex_ai_creds_json()
vertex_credentials = json.dumps(json_obj)
should_use_v1beta1_features = vertex_llm.is_using_v1beta1_features(
optional_params={"cached_content": "hi"}
)
assert should_use_v1beta1_features is True
_, url = vertex_llm._get_token_and_url(
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
gemini_api_key="",
custom_llm_provider="vertex_ai_beta",
should_use_v1beta1_features=should_use_v1beta1_features,
api_base=None,
model="",
stream=False,
)
print("url=", url)
assert "/v1beta1/" in url
should_use_v1beta1_features = vertex_llm.is_using_v1beta1_features(
optional_params={"temperature": 0.1}
)
_, url = vertex_llm._get_token_and_url(
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
gemini_api_key="",
custom_llm_provider="vertex_ai_beta",
should_use_v1beta1_features=should_use_v1beta1_features,
api_base=None,
model="",
stream=False,
)
print("url for normal request", url)
assert "v1beta1" not in url
assert "/v1/" in url
pass