Merge pull request #5391 from BerriAI/litellm_add_ai21_support

[Feat] Add Vertex AI21 support

Commit 6ab8cbc105. 10 changed files with 390 additions and 64 deletions.
@@ -983,6 +983,85 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \

</Tabs>

## AI21 Models

| Model Name | Function Call |
|------------------|--------------------------------------|
| jamba-1.5-mini@001 | `completion(model='vertex_ai/jamba-1.5-mini@001', messages)` |
| jamba-1.5-large@001 | `completion(model='vertex_ai/jamba-1.5-large@001', messages)` |

### Usage

<Tabs>
<TabItem value="sdk" label="SDK">

```python
from litellm import completion
import os

os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = ""

model = "jamba-1.5-mini@001"

vertex_ai_project = "your-vertex-project"  # can also set this as os.environ["VERTEXAI_PROJECT"]
vertex_ai_location = "your-vertex-location"  # can also set this as os.environ["VERTEXAI_LOCATION"]

response = completion(
    model="vertex_ai/" + model,
    messages=[{"role": "user", "content": "hi"}],
    vertex_ai_project=vertex_ai_project,
    vertex_ai_location=vertex_ai_location,
)
print("\nModel Response", response)
```

</TabItem>
<TabItem value="proxy" label="Proxy">

**1. Add to config**

```yaml
model_list:
  - model_name: jamba-1.5-mini
    litellm_params:
      model: vertex_ai/jamba-1.5-mini@001
      vertex_ai_project: "my-test-project"
      vertex_ai_location: "us-east1"
  - model_name: jamba-1.5-large
    litellm_params:
      model: vertex_ai/jamba-1.5-large@001
      vertex_ai_project: "my-test-project"
      vertex_ai_location: "us-west1"
```

**2. Start proxy**

```bash
litellm --config /path/to/config.yaml

# RUNNING at http://0.0.0.0:4000
```

**3. Test it!**

```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data '{
      "model": "jamba-1.5-large",
      "messages": [
        {
          "role": "user",
          "content": "what llm are you"
        }
      ]
    }'
```

</TabItem>
</Tabs>

### Usage - Codestral FIM
@@ -859,9 +859,12 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (

```diff
 from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
     VertexAIAnthropicConfig,
 )
-from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import (
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transformation import (
     VertexAILlama3Config,
 )
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.ai21.transformation import (
+    VertexAIAi21Config,
+)
 from .llms.sagemaker.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig
```
@@ -48,6 +48,7 @@ def cost_router(

```diff
         "claude" in model
         or "llama" in model
         or "mistral" in model
+        or "jamba" in model
         or "codestral" in model
     ):
         return "cost_per_token"
```
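In plain terms, the new `"jamba"` entry means AI21's Jamba models on Vertex AI now route to per-token cost calculation. A minimal standalone sketch of the substring check (not the real `cost_router`, whose full signature is outside this hunk):

```python
# Standalone sketch of the routing condition above; "jamba-1.5-mini@001"
# now matches via the new "jamba" substring.
def routes_to_cost_per_token(model: str) -> bool:
    return any(
        s in model for s in ("claude", "llama", "mistral", "jamba", "codestral")
    )

assert routes_to_cost_per_token("jamba-1.5-mini@001")
assert not routes_to_cost_per_token("gemini-1.5-pro")
```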
@@ -363,6 +363,12 @@ class DatabricksChatCompletion(BaseLLM):

```diff
         except Exception as e:
             raise DatabricksError(status_code=500, message=str(e))
 
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response=response_json,
+            additional_args={"complete_input_dict": data},
+        )
         response = ModelResponse(**response_json)
 
         if base_model is not None:
```
@@ -0,0 +1,53 @@

```python
import types
from typing import Callable, Literal, Optional, Union

import litellm


class VertexAIAi21Config:
    """
    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/ai21

    The class `VertexAIAi21Config` provides configuration for Vertex AI's AI21 API interface.

    -> Supports all OpenAI parameters
    """

    def __init__(
        self,
        max_tokens: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        # Collect class-level config values, skipping dunders and methods.
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict, model: str
    ):
        return litellm.OpenAIConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
        )
```
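For orientation, a quick sketch of how this config behaves. The parameter values are illustrative, and the expected output assumes `OpenAIConfig` passes standard OpenAI params through unchanged:

```python
# Illustrative usage of VertexAIAi21Config; the values are made up.
import litellm

mapped = litellm.VertexAIAi21Config().map_openai_params(
    non_default_params={"max_tokens": 256, "top_p": 0.5},
    optional_params={},
    model="jamba-1.5-mini@001",
)
print(mapped)  # expected: {"max_tokens": 256, "top_p": 0.5}
```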
@@ -0,0 +1,59 @@

```python
import types
from typing import Callable, Literal, Optional, Union

import litellm


class VertexAILlama3Config:
    """
    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming

    The class `VertexAILlama3Config` provides configuration for Vertex AI's Llama API interface. Below are the parameters:

    - `max_tokens` Optional (integer): maximum number of tokens to generate

    Note: Please make sure to modify the default parameters as required for your use case.
    """

    max_tokens: Optional[int] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key == "max_tokens" and value is None:
                value = self.max_tokens
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        # Collect class-level config values, skipping dunders and methods.
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict, model: str
    ):
        return litellm.OpenAIConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
        )
```
@@ -1,6 +1,7 @@

```diff
 # What is this?
-## Handler for calling llama 3.1 API on Vertex AI
+## API Handler for calling Vertex AI Partner Models
 import types
+from enum import Enum
 from typing import Callable, Literal, Optional, Union

 import httpx  # type: ignore
```

@@ -8,7 +9,13 @@ import httpx  # type: ignore

```diff
 import litellm
 from litellm.utils import ModelResponse

-from ..base import BaseLLM
+from ...base import BaseLLM


+class VertexPartnerProvider(str, Enum):
+    mistralai = "mistralai"
+    llama = "llama"
+    ai21 = "ai21"
+
+
 class VertexAIError(Exception):
```

@@ -24,61 +31,6 @@ class VertexAIError(Exception):

```diff
         )  # Call the base class constructor with the parameters it needs


-class VertexAILlama3Config:
-    """
-    Reference:https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming
-
-    The class `VertexAILlama3Config` provides configuration for the VertexAI's Llama API interface. Below are the parameters:
-
-    - `max_tokens` Required (integer) max tokens,
-
-    Note: Please make sure to modify the default parameters as required for your use case.
-    """
-
-    max_tokens: Optional[int] = None
-
-    def __init__(
-        self,
-        max_tokens: Optional[int] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key == "max_tokens" and value is None:
-                value = self.max_tokens
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(self):
-        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
-
-    def map_openai_params(
-        self, non_default_params: dict, optional_params: dict, model: str
-    ):
-        return litellm.OpenAIConfig().map_openai_params(
-            non_default_params=non_default_params,
-            optional_params=optional_params,
-            model=model,
-        )
-
-
 class VertexAIPartnerModels(BaseLLM):
     def __init__(self) -> None:
         pass
```
@@ -87,17 +39,22 @@ class VertexAIPartnerModels(BaseLLM):

```diff
         self,
         vertex_location: str,
         vertex_project: str,
-        partner: Literal["llama", "mistralai"],
+        partner: VertexPartnerProvider,
         stream: Optional[bool],
         model: str,
     ) -> str:
-        if partner == "llama":
+        if partner == VertexPartnerProvider.llama:
             return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
-        elif partner == "mistralai":
+        elif partner == VertexPartnerProvider.mistralai:
             if stream:
                 return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
             else:
                 return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
+        elif partner == VertexPartnerProvider.ai21:
+            if stream:
+                return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
+            else:
+                return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"

     def completion(
         self,
```
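For orientation, the endpoint the new `ai21` branch produces; the project and location values below are placeholders, not values from the commit:

```python
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
    VertexAIPartnerModels,
    VertexPartnerProvider,
)

# Placeholder project/location; create_vertex_url only does string formatting.
url = VertexAIPartnerModels().create_vertex_url(
    vertex_location="us-central1",
    vertex_project="my-project",
    partner=VertexPartnerProvider.ai21,
    stream=False,
    model="jamba-1.5-mini@001",
)
# -> https://us-central1-aiplatform.googleapis.com/v1beta1/projects/my-project/locations/us-central1/publishers/ai21/models/jamba-1.5-mini@001:rawPredict
```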
@@ -160,9 +117,12 @@ class VertexAIPartnerModels(BaseLLM):

```diff
             optional_params["stream"] = stream

             if "llama" in model:
-                partner = "llama"
+                partner = VertexPartnerProvider.llama
             elif "mistral" in model or "codestral" in model:
-                partner = "mistralai"
+                partner = VertexPartnerProvider.mistralai
+                optional_params["custom_endpoint"] = True
+            elif "jamba" in model:
+                partner = VertexPartnerProvider.ai21
                 optional_params["custom_endpoint"] = True

             api_base = self.create_vertex_url(
```
@@ -126,7 +126,7 @@ from .llms.vertex_ai_and_google_ai_studio import (

```diff
     vertex_ai_anthropic,
     vertex_ai_non_gemini,
 )
-from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import (
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
     VertexAIPartnerModels,
 )
 from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
```
@@ -2080,6 +2080,7 @@ def completion(

```diff
                 model.startswith("meta/")
                 or model.startswith("mistral")
                 or model.startswith("codestral")
+                or model.startswith("jamba")
             ):
                 model_response = vertex_partner_models_chat_completion.completion(
                     model=model,
```
@@ -2356,3 +2356,157 @@ async def test_gemini_context_caching_anthropic_format():

```python
    # (tail of the existing test, unchanged context)
    check_cache_mock.assert_called_once()
    assert mock_client.call_count == 3


@pytest.mark.asyncio
async def test_partner_models_httpx_ai21():
    litellm.set_verbose = True
    model = "vertex_ai/jamba-1.5-mini@001"

    messages = [
        {
            "role": "system",
            "content": "Your name is Litellm Bot, you are a helpful assistant",
        },
        {
            "role": "user",
            "content": "Hello, can you tell me the weather in San Francisco?",
        },
    ]

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        }
                    },
                    "required": ["location"],
                },
            },
        }
    ]

    data = {
        "model": model,
        "messages": messages,
        "tools": tools,
        "top_p": 0.5,
    }

    mock_response = AsyncMock()

    def return_val():
        return {
            "id": "chat-3d11cf95eb224966937b216d9494fe73",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": " Sure, let me check that for you.",
                        "tool_calls": [
                            {
                                "id": "b5cef16b-5946-4937-b9d5-beeaea871e77",
                                "type": "function",
                                "function": {
                                    "name": "get_weather",
                                    "arguments": '{"location": "San Francisco"}',
                                },
                            }
                        ],
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 158,
                "completion_tokens": 36,
                "total_tokens": 194,
            },
            "meta": {"requestDurationMillis": 501},
            "model": "jamba-1.5",
        }

    mock_response.json = return_val
    mock_response.status_code = 200

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.acompletion(**data)

        # Assert
        mock_post.assert_called_once()
        url, kwargs = mock_post.call_args
        print("url = ", url)
        print("call args = ", kwargs)

        print(kwargs["data"])

        assert (
            url[0]
            == "https://us-central1-aiplatform.googleapis.com/v1beta1/projects/adroit-crow-413218/locations/us-central1/publishers/ai21/models/jamba-1.5-mini@001:rawPredict"
        )

        # json loads kwargs
        kwargs["data"] = json.loads(kwargs["data"])

        assert kwargs["data"] == {
            "model": "jamba-1.5-mini",
            "messages": [
                {
                    "role": "system",
                    "content": "Your name is Litellm Bot, you are a helpful assistant",
                },
                {
                    "role": "user",
                    "content": "Hello, can you tell me the weather in San Francisco?",
                },
            ],
            "top_p": 0.5,
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the current weather in a given location",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "location": {
                                    "type": "string",
                                    "description": "The city and state, e.g. San Francisco, CA",
                                }
                            },
                            "required": ["location"],
                        },
                    },
                }
            ],
            "stream": False,
        }

        assert response.id == "chat-3d11cf95eb224966937b216d9494fe73"
        assert len(response.choices) == 1
        assert (
            response.choices[0].message.content == " Sure, let me check that for you."
        )
        assert response.choices[0].message.tool_calls[0].function.name == "get_weather"
        assert (
            response.choices[0].message.tool_calls[0].function.arguments
            == '{"location": "San Francisco"}'
        )
        assert response.usage.prompt_tokens == 158
        assert response.usage.completion_tokens == 36
        assert response.usage.total_tokens == 194

    print(f"response: {response}")
```
@@ -3267,6 +3267,16 @@ def get_optional_params(

```diff
             non_default_params=non_default_params,
             optional_params=optional_params,
         )
+    elif custom_llm_provider == "vertex_ai" and model in litellm.ai21_models:
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.VertexAIAi21Config().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+        )
     elif custom_llm_provider == "sagemaker":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
```
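To make the validation step concrete, a simplified stand-in for what this branch does. `_check_valid_arg` is internal to `get_optional_params`, so the rejection logic below is illustrative only:

```python
# Simplified sketch of the supported-params check; not litellm's actual
# _check_valid_arg implementation.
import litellm

supported = litellm.VertexAIAi21Config().get_supported_openai_params()
non_default_params = {"top_p": 0.5}  # illustrative user-supplied params

for param in non_default_params:
    if param not in supported:
        raise ValueError(f"{param} is not supported for vertex_ai jamba models")
```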