Merge branch 'main' into litellm_anthropic_tool_calling_streaming_fix

Krish Dholakia authored on 2024-07-03 20:43:51 -07:00, committed by GitHub
commit 06c6c65d2a
24 changed files with 868 additions and 508 deletions


@@ -15,7 +15,6 @@ import requests  # type: ignore
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
 from litellm.types.utils import ResponseFormatChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
@@ -122,17 +121,6 @@ class VertexAIAnthropicConfig:
                 optional_params["max_tokens"] = value
             if param == "tools":
                 optional_params["tools"] = value
-            if param == "tool_choice":
-                _tool_choice: Optional[AnthropicMessagesToolChoice] = None
-                if value == "auto":
-                    _tool_choice = {"type": "auto"}
-                elif value == "required":
-                    _tool_choice = {"type": "any"}
-                elif isinstance(value, dict):
-                    _tool_choice = {"type": "tool", "name": value["function"]["name"]}
-                if _tool_choice is not None:
-                    optional_params["tool_choice"] = _tool_choice
             if param == "stream":
                 optional_params["stream"] = value
             if param == "stop":
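
For reference, the block removed above performed the OpenAI-to-Anthropic tool_choice translation: "auto" passes through, "required" becomes Anthropic's "any", and selecting a specific function becomes a forced named-tool call. A minimal standalone sketch of that mapping (the function name and the assert below are illustrative, not litellm API):

    from typing import Optional, Union

    def map_tool_choice(value: Union[str, dict]) -> Optional[dict]:
        if value == "auto":
            # Model decides whether to call a tool.
            return {"type": "auto"}
        if value == "required":
            # OpenAI "required" == Anthropic "any": some tool must be called.
            return {"type": "any"}
        if isinstance(value, dict):
            # e.g. {"type": "function", "function": {"name": "get_weather"}}
            # forces a call to that specific tool.
            return {"type": "tool", "name": value["function"]["name"]}
        return None

    assert map_tool_choice("required") == {"type": "any"}
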
@@ -189,29 +177,17 @@ def get_vertex_client(
         _credentials, cred_project_id = VertexLLM().load_auth(
             credentials=vertex_credentials, project_id=vertex_project
         )
         vertex_ai_client = AnthropicVertex(
             project_id=vertex_project or cred_project_id,
             region=vertex_location or "us-central1",
             access_token=_credentials.token,
         )
         access_token = _credentials.token
     else:
         vertex_ai_client = client
         access_token = client.access_token
     return vertex_ai_client, access_token
-def create_vertex_anthropic_url(
-    vertex_location: str, vertex_project: str, model: str, stream: bool
-) -> str:
-    if stream is True:
-        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
-    else:
-        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"
 def completion(
     model: str,
     messages: list,
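
The create_vertex_anthropic_url helper that disappears in this hunk built the regional Vertex AI publisher endpoint, picking :streamRawPredict for streaming calls and :rawPredict otherwise. A sketch of the same construction and a sample result (the project, region, and model values below are made up):

    def vertex_anthropic_url(
        vertex_location: str, vertex_project: str, model: str, stream: bool
    ) -> str:
        # Streaming and non-streaming requests hit different predict verbs.
        verb = "streamRawPredict" if stream else "rawPredict"
        return (
            f"https://{vertex_location}-aiplatform.googleapis.com/v1"
            f"/projects/{vertex_project}/locations/{vertex_location}"
            f"/publishers/anthropic/models/{model}:{verb}"
        )

    # -> https://us-east5-aiplatform.googleapis.com/v1/projects/my-project/
    #    locations/us-east5/publishers/anthropic/models/claude-3-sonnet@20240229:streamRawPredict
    print(vertex_anthropic_url("us-east5", "my-project", "claude-3-sonnet@20240229", stream=True))
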
@@ -220,8 +196,6 @@ def completion(
     encoding,
     logging_obj,
     optional_params: dict,
-    custom_prompt_dict: dict,
-    headers: Optional[dict],
     vertex_project=None,
     vertex_location=None,
     vertex_credentials=None,
@@ -233,9 +207,6 @@ def completion(
     try:
         import vertexai
         from anthropic import AnthropicVertex
-        from litellm.llms.anthropic import AnthropicChatCompletion
-        from litellm.llms.vertex_httpx import VertexLLM
     except:
         raise VertexAIError(
             status_code=400,
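
The try/except around these imports is the usual optional-dependency guard: the Vertex Anthropic path needs extras that are not hard requirements, so an import failure is converted into an actionable error instead of a bare ImportError. A generic sketch of the pattern (the message wording here is illustrative):

    try:
        import vertexai  # optional extra: google-cloud-aiplatform
        from anthropic import AnthropicVertex  # optional extra: anthropic[vertex]
    except ImportError as e:
        raise RuntimeError(
            'vertexai import failed; run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`'
        ) from e
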
@@ -251,14 +222,13 @@ def completion(
         )
     try:
-        vertex_httpx_logic = VertexLLM()
-        access_token, project_id = vertex_httpx_logic._ensure_access_token(
-            credentials=vertex_credentials, project_id=vertex_project
+        vertex_ai_client, access_token = get_vertex_client(
+            client=client,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
         )
-        anthropic_chat_completions = AnthropicChatCompletion()
         ## Load Config
         config = litellm.VertexAIAnthropicConfig.get_config()
         for k, v in config.items():
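
The excerpt cuts off inside the config-merge loop. The idiom this loop typically follows in litellm is to apply provider config defaults only where the caller has not set a value; a hedged sketch of how it usually continues (the loop body itself is not shown in this diff):

    config = litellm.VertexAIAnthropicConfig.get_config()
    for k, v in config.items():
        # Config values are defaults: never override caller-supplied params.
        if k not in optional_params:
            optional_params[k] = v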