forked from phoenix/litellm-mirror
fix: linting fixes
parent 17869fc5e9
commit 19c982d0f9
5 changed files with 62 additions and 16 deletions
@@ -426,6 +426,7 @@ class Logging:
            self.model_call_details["additional_args"] = additional_args
            self.model_call_details["log_event_type"] = "post_api_call"

            if json_logs:
                verbose_logger.debug(
                    "RAW RESPONSE:\n{}\n\n".format(
                        self.model_call_details.get(
@@ -433,6 +434,14 @@ class Logging:
                        )
                    ),
                )
            else:
                print_verbose(
                    "RAW RESPONSE:\n{}\n\n".format(
                        self.model_call_details.get(
                            "original_response", self.model_call_details
                        )
                    )
                )
            if self.logger_fn and callable(self.logger_fn):
                try:
                    self.logger_fn(
@@ -601,13 +601,16 @@ class AnthropicChatCompletion(BaseLLM):
            optional_params["tools"] = anthropic_tools

        stream = optional_params.pop("stream", None)
        is_vertex_request: bool = optional_params.pop("is_vertex_request", False)

        data = {
            "model": model,
            "messages": messages,
            **optional_params,
        }

        if is_vertex_request is False:
            data["model"] = model

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
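Note: reading the hunk above, the request body is built from optional_params, and "model" is attached only for non-Vertex calls (for Vertex requests the model is encoded in the request URL, as the create_vertex_anthropic_url helper later in this diff shows). A minimal illustrative sketch of the two payload shapes (values are made up, not taken from the commit):

# Illustrative only -- not verbatim from the diff.
messages = [{"role": "user", "content": "Hey!"}]
optional_params = {"max_tokens": 256}

def build_payload(is_vertex_request: bool) -> dict:
    data = {"messages": messages, **optional_params}
    if is_vertex_request is False:
        data["model"] = "claude-3-sonnet-20240229"  # direct Anthropic API call
    return data

print(build_payload(is_vertex_request=True))   # no "model" key; the Vertex URL carries the model
print(build_payload(is_vertex_request=False))  # includes "model"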
@@ -15,6 +15,7 @@ import requests  # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
from litellm.types.utils import ResponseFormatChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
@@ -121,6 +122,17 @@ class VertexAIAnthropicConfig:
                optional_params["max_tokens"] = value
            if param == "tools":
                optional_params["tools"] = value
            if param == "tool_choice":
                _tool_choice: Optional[AnthropicMessagesToolChoice] = None
                if value == "auto":
                    _tool_choice = {"type": "auto"}
                elif value == "required":
                    _tool_choice = {"type": "any"}
                elif isinstance(value, dict):
                    _tool_choice = {"type": "tool", "name": value["function"]["name"]}

                if _tool_choice is not None:
                    optional_params["tool_choice"] = _tool_choice
            if param == "stream":
                optional_params["stream"] = value
            if param == "stop":
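The tool_choice branch above maps OpenAI-style values onto Anthropic's schema. A standalone sketch of that mapping (hypothetical helper name, not part of the commit):

from typing import Optional, Union

def map_tool_choice(value: Union[str, dict]) -> Optional[dict]:
    # Mirrors the mapping in the hunk above: "auto" -> auto, "required" -> any,
    # and an OpenAI-style {"function": {"name": ...}} dict -> a named tool.
    if value == "auto":
        return {"type": "auto"}
    if value == "required":
        return {"type": "any"}
    if isinstance(value, dict):
        return {"type": "tool", "name": value["function"]["name"]}
    return None

print(map_tool_choice("required"))  # {'type': 'any'}
print(map_tool_choice({"type": "function", "function": {"name": "get_weather"}}))
# {'type': 'tool', 'name': 'get_weather'}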
@@ -177,17 +189,29 @@ def get_vertex_client(
        _credentials, cred_project_id = VertexLLM().load_auth(
            credentials=vertex_credentials, project_id=vertex_project
        )

        vertex_ai_client = AnthropicVertex(
            project_id=vertex_project or cred_project_id,
            region=vertex_location or "us-central1",
            access_token=_credentials.token,
        )
        access_token = _credentials.token
    else:
        vertex_ai_client = client
        access_token = client.access_token

    return vertex_ai_client, access_token
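A hedged usage sketch of the get_vertex_client helper above, using the keyword arguments that appear at its call site later in this diff (project and location values are placeholders):

# Illustrative only; "my-gcp-project" / "us-east5" are placeholders.
vertex_ai_client, access_token = get_vertex_client(
    client=None,                      # no pre-built AnthropicVertex client, so one is constructed
    vertex_project="my-gcp-project",
    vertex_location="us-east5",
    vertex_credentials=None,          # placeholder; normally a service-account credentials path or JSON
)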
def create_vertex_anthropic_url(
    vertex_location: str, vertex_project: str, model: str, stream: bool
) -> str:
    if stream is True:
        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
    else:
        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"
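For reference, the URLs the new create_vertex_anthropic_url helper produces (placeholder project and location):

url = create_vertex_anthropic_url(
    vertex_location="us-east5",
    vertex_project="my-gcp-project",
    model="claude-3-sonnet@20240229",
    stream=False,
)
# -> https://us-east5-aiplatform.googleapis.com/v1/projects/my-gcp-project/locations/us-east5/publishers/anthropic/models/claude-3-sonnet@20240229:rawPredict
# With stream=True the path ends in :streamRawPredict instead.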
def completion(
    model: str,
    messages: list,
@@ -196,6 +220,8 @@ def completion(
    encoding,
    logging_obj,
    optional_params: dict,
    custom_prompt_dict: dict,
    headers: Optional[dict],
    vertex_project=None,
    vertex_location=None,
    vertex_credentials=None,
@@ -207,6 +233,9 @@ def completion(
    try:
        import vertexai
        from anthropic import AnthropicVertex

        from litellm.llms.anthropic import AnthropicChatCompletion
        from litellm.llms.vertex_httpx import VertexLLM
    except:
        raise VertexAIError(
            status_code=400,
@@ -222,13 +251,14 @@ def completion(
        )
    try:
        vertex_ai_client, access_token = get_vertex_client(
            client=client,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
        vertex_httpx_logic = VertexLLM()

        access_token, project_id = vertex_httpx_logic._ensure_access_token(
            credentials=vertex_credentials, project_id=vertex_project
        )

        anthropic_chat_completions = AnthropicChatCompletion()

        ## Load Config
        config = litellm.VertexAIAnthropicConfig.get_config()
        for k, v in config.items():
@@ -2008,6 +2008,8 @@ def completion(
                vertex_credentials=vertex_credentials,
                logging_obj=logging,
                acompletion=acompletion,
                headers=headers,
                custom_prompt_dict=custom_prompt_dict,
            )
        else:
            model_response = vertex_ai.completion(
@@ -640,11 +640,13 @@ def test_gemini_pro_vision_base64():
        pytest.fail(f"An exception occurred - {str(e)}")


@pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
# @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
@pytest.mark.parametrize(
    "model", ["vertex_ai_beta/gemini-1.5-pro", "vertex_ai/claude-3-sonnet@20240229"]
)  # "vertex_ai",
@pytest.mark.parametrize("sync_mode", [True])  # "vertex_ai",
@pytest.mark.asyncio
async def test_gemini_pro_function_calling_httpx(provider, sync_mode):
async def test_gemini_pro_function_calling_httpx(model, sync_mode):
    try:
        load_vertex_ai_credentials()
        litellm.set_verbose = True
@@ -682,7 +684,7 @@ async def test_gemini_pro_function_calling_httpx(provider, sync_mode):
    ]
    data = {
        "model": "{}/gemini-1.5-pro".format(provider),
        "model": model,
        "messages": messages,
        "tools": tools,
        "tool_choice": "required",
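For context, a hedged sketch of how a request dict like the one built in this test would be sent through litellm (the test's actual call is outside the hunks shown; the tool definition below is illustrative):

import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

data = {
    "model": "vertex_ai/claude-3-sonnet@20240229",
    "messages": [{"role": "user", "content": "What is the weather like in Boston?"}],
    "tools": tools,
    "tool_choice": "required",
}

response = litellm.completion(**data)  # needs valid Vertex AI credentials configured
print(response.choices[0].message.tool_calls)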