Merge branch 'BerriAI:main' into fix-anthropic-messages-api
This commit is contained in: commit 38b5f34c77
366 changed files with 73092 additions and 56717 deletions
@@ -298,7 +298,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

@@ -2,18 +2,13 @@ import os, types
import json
from enum import Enum
import requests, copy
import time, uuid
import time
from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm
from .prompt_templates.factory import (
contains_tag,
prompt_factory,
custom_prompt,
construct_tool_use_system_prompt,
extract_between_tags,
parse_xml_params,
)
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx

@@ -21,6 +16,8 @@ class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
# constants from https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/_constants.py
class AnthropicError(Exception):
def __init__(self, status_code, message):

@@ -37,12 +34,14 @@ class AnthropicError(Exception):
class AnthropicConfig:
"""
Reference: https://docs.anthropic.com/claude/reference/complete_post
Reference: https://docs.anthropic.com/claude/reference/messages_post
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
"""
max_tokens: Optional[int] = litellm.max_tokens # anthropic requires a default
max_tokens: Optional[int] = (
4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
)
stop_sequences: Optional[list] = None
temperature: Optional[int] = None
top_p: Optional[int] = None

@@ -52,7 +51,9 @@ class AnthropicConfig:
def __init__(
self,
max_tokens: Optional[int] = 256, # anthropic requires a default
max_tokens: Optional[
int
] = 4096, # You can pass in a value yourself or use the default value 4096
stop_sequences: Optional[list] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,

@@ -101,124 +102,23 @@ def validate_environment(api_key, user_headers):
return headers
def completion(
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
):
headers = validate_environment(api_key, headers)
_is_function_call = False
json_schemas: dict = {}
messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
# Separate system prompt from rest of message
system_prompt_indices = []
system_prompt = ""
for idx, message in enumerate(messages):
if message["role"] == "system":
system_prompt += message["content"]
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
if len(system_prompt) > 0:
optional_params["system"] = system_prompt
# Format rest of message according to anthropic guidelines
try:
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
except Exception as e:
raise AnthropicError(status_code=400, message=str(e))
## Load Config
config = litellm.AnthropicConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
for tool in optional_params["tools"]:
json_schemas[tool["function"]["name"]] = tool["function"].get(
"parameters", None
)
tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=optional_params["tools"]
)
optional_params["system"] = (
optional_params.get("system", "\n") + tool_calling_system_prompt
) # add the anthropic tool calling prompt to the system prompt
optional_params.pop("tools")
stream = optional_params.pop("stream", None)
data = {
"model": model,
"messages": messages,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
print_verbose(f"_is_function_call: {_is_function_call}")
## COMPLETION CALL
if (
stream is not None and stream == True and _is_function_call == False
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose(f"makes anthropic streaming POST request")
data["stream"] = stream
response = requests.post(
api_base,
headers=headers,
data=json.dumps(data),
stream=stream,
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
return response.iter_lines()
else:
response = requests.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
class AnthropicChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def process_response(
self,
model,
response,
model_response,
_is_function_call,
stream,
logging_obj,
api_key,
data,
messages,
print_verbose,
):
## LOGGING
logging_obj.post_call(
input=messages,

@@ -245,46 +145,40 @@ def completion(
status_code=response.status_code,
)
else:
text_content = completion_response["content"][0].get("text", None)
## TOOL CALLING - OUTPUT PARSE
if text_content is not None and contains_tag("invoke", text_content):
function_name = extract_between_tags("tool_name", text_content)[0]
function_arguments_str = extract_between_tags("invoke", text_content)[
0
].strip()
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
function_arguments = parse_xml_params(
function_arguments_str,
json_schema=json_schemas.get(
function_name, None
), # check if we have a json schema for this function name
)
_message = litellm.Message(
tool_calls=[
text_content = ""
tool_calls = []
for content in completion_response["content"]:
if content["type"] == "text":
text_content += content["text"]
## TOOL CALLING
elif content["type"] == "tool_use":
tool_calls.append(
{
"id": f"call_{uuid.uuid4()}",
"id": content["id"],
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(function_arguments),
"name": content["name"],
"arguments": json.dumps(content["input"]),
},
}
],
content=None,
)
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = (
text_content # allow user to access raw anthropic tool calling response
)
else:
model_response.choices[0].message.content = text_content # type: ignore
)
_message = litellm.Message(
tool_calls=tool_calls,
content=text_content or None,
)
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = completion_response[
"content"
] # allow user to access raw anthropic tool calling response
model_response.choices[0].finish_reason = map_finish_reason(
completion_response["stop_reason"]
)
print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
if _is_function_call == True and stream is not None and stream == True:
print_verbose(f"INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
if _is_function_call and stream:
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# return an iterator
streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = model_response.choices[

@@ -318,7 +212,7 @@ def completion(
model_response=streaming_model_response
)
print_verbose(
f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
)
return CustomStreamWrapper(
completion_stream=completion_stream,

@@ -337,11 +231,278 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
total_tokens=total_tokens,
)
model_response.usage = usage
return model_response
async def acompletion_stream_function(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
stream,
_is_function_call,
data=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
data["stream"] = True
response = await self.async_handler.post(
api_base, headers=headers, data=json.dumps(data), stream=True
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
completion_stream = response.aiter_lines()
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic",
logging_obj=logging_obj,
)
return streamwrapper
async def acompletion_function(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
stream,
_is_function_call,
data=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
response = await self.async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)
return self.process_response(
model=model,
response=response,
model_response=model_response,
_is_function_call=_is_function_call,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
)
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params=None,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
):
headers = validate_environment(api_key, headers)
_is_function_call = False
messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
# Separate system prompt from rest of message
system_prompt_indices = []
system_prompt = ""
for idx, message in enumerate(messages):
if message["role"] == "system":
system_prompt += message["content"]
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
if len(system_prompt) > 0:
optional_params["system"] = system_prompt
# Format rest of message according to anthropic guidelines
try:
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
except Exception as e:
raise AnthropicError(status_code=400, message=str(e))
## Load Config
config = litellm.AnthropicConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
headers["anthropic-beta"] = "tools-2024-04-04"
anthropic_tools = []
for tool in optional_params["tools"]:
new_tool = tool["function"]
new_tool["input_schema"] = new_tool.pop("parameters") # rename key
anthropic_tools.append(new_tool)
optional_params["tools"] = anthropic_tools
stream = optional_params.pop("stream", None)
data = {
"model": model,
"messages": messages,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
print_verbose(f"_is_function_call: {_is_function_call}")
if acompletion == True:
if (
stream and not _is_function_call
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose("makes async anthropic streaming POST request")
data["stream"] = stream
return self.acompletion_stream_function(
model=model,
messages=messages,
data=data,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream,
_is_function_call=_is_function_call,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
)
else:
return self.acompletion_function(
model=model,
messages=messages,
data=data,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream,
_is_function_call=_is_function_call,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
)
else:
## COMPLETION CALL
if (
stream and not _is_function_call
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose("makes anthropic streaming POST request")
data["stream"] = stream
response = requests.post(
api_base,
headers=headers,
data=json.dumps(data),
stream=stream,
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
completion_stream = response.iter_lines()
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic",
logging_obj=logging_obj,
)
return streaming_response
else:
response = requests.post(
api_base, headers=headers, data=json.dumps(data)
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
return self.process_response(
model=model,
response=response,
model_response=model_response,
_is_function_call=_is_function_call,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
)
def embedding(self):
# logic for parsing in - calling - parsing out model embedding calls
pass
class ModelResponseIterator:
def __init__(self, model_response):

@@ -367,8 +528,3 @@ class ModelResponseIterator:
raise StopAsyncIteration
self.is_done = True
return self.model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass

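The tool-calling change above replaces the XML system-prompt workaround with Anthropic's native tools parameter: each OpenAI-style tool's "parameters" field is renamed to "input_schema" and the "anthropic-beta: tools-2024-04-04" header is added. A minimal sketch of that conversion follows; the helper name and the sample tool are illustrative assumptions, and unlike the diff (which mutates the tool dict in place) this copy leaves the caller's dict untouched.

# Illustrative sketch of the OpenAI -> Anthropic tool conversion shown in the diff above.
# Only the key rename and the beta header mirror the changed code; everything else is assumed.
def to_anthropic_tools(openai_tools):
    headers = {"anthropic-beta": "tools-2024-04-04"}  # beta header enabling native tools
    anthropic_tools = []
    for tool in openai_tools:
        new_tool = dict(tool["function"])  # copy so the original tool definition is preserved
        new_tool["input_schema"] = new_tool.pop("parameters")  # rename key, as in the diff
        anthropic_tools.append(new_tool)
    return anthropic_tools, headers

sample_tool = {  # hypothetical tool definition for demonstration only
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the weather for a city",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    },
}
tools, extra_headers = to_anthropic_tools([sample_tool])
print(tools[0]["input_schema"], extra_headers)
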
@@ -4,10 +4,12 @@ from enum import Enum
import requests
import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
import httpx
from .base import BaseLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
class AnthropicConstants(Enum):

@@ -94,91 +96,13 @@ def validate_environment(api_key, user_headers):
return headers
def completion(
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
):
headers = validate_environment(api_key, headers)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
class AnthropicTextCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
## Load Config
config = litellm.AnthropicTextConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
"model": model,
"prompt": prompt,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True:
response = requests.post(
api_base,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"],
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
return response.iter_lines()
else:
response = requests.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
def process_response(
self, model_response: ModelResponse, response, encoding, prompt: str, model: str
):
## RESPONSE OBJECT
try:
completion_response = response.json()

@@ -213,10 +137,208 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response
async def async_completion(
self,
model: str,
model_response: ModelResponse,
api_base: str,
logging_obj,
encoding,
headers: dict,
data: dict,
client=None,
):
if client is None:
client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
response = await client.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
input=data["prompt"],
api_key=headers.get("x-api-key"),
original_response=response.text,
additional_args={"complete_input_dict": data},
)
response = self.process_response(
model_response=model_response,
response=response,
encoding=encoding,
prompt=data["prompt"],
model=model,
)
return response
async def async_streaming(
self,
model: str,
api_base: str,
logging_obj,
headers: dict,
data: Optional[dict],
client=None,
):
if client is None:
client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = await client.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
completion_stream = response.aiter_lines()
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic_text",
logging_obj=logging_obj,
)
return streamwrapper
def completion(
self,
model: str,
messages: list,
api_base: str,
acompletion: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
client=None,
):
headers = validate_environment(api_key, headers)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
## Load Config
config = litellm.AnthropicTextConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
"model": model,
"prompt": prompt,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True:
if acompletion == True:
return self.async_streaming(
model=model,
api_base=api_base,
logging_obj=logging_obj,
headers=headers,
data=data,
client=None,
)
if client is None:
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = client.post(
api_base,
headers=headers,
data=json.dumps(data),
# stream=optional_params["stream"],
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
completion_stream = response.iter_lines()
stream_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic_text",
logging_obj=logging_obj,
)
return stream_response
elif acompletion == True:
return self.async_completion(
model=model,
model_response=model_response,
api_base=api_base,
logging_obj=logging_obj,
encoding=encoding,
headers=headers,
data=data,
client=client,
)
else:
if client is None:
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = client.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
response = self.process_response(
model_response=model_response,
response=response,
encoding=encoding,
prompt=data["prompt"],
model=model,
)
return response
def embedding(self):
# logic for parsing in - calling - parsing out model embedding calls
pass

@@ -799,6 +799,7 @@ class AzureChatCompletion(BaseLLM):
optional_params: dict,
model_response: TranscriptionResponse,
timeout: float,
max_retries: int,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,

@@ -817,8 +818,6 @@ class AzureChatCompletion(BaseLLM):
"timeout": timeout,
}
max_retries = optional_params.pop("max_retries", None)
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)

@@ -8,6 +8,7 @@ from litellm.utils import (
CustomStreamWrapper,
convert_to_model_response_object,
TranscriptionResponse,
TextCompletionResponse,
)
from typing import Callable, Optional, BinaryIO
from litellm import OpenAIConfig

@@ -15,11 +16,11 @@ import litellm, json
import httpx
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI
from ..llms.openai import OpenAITextCompletion
from ..llms.openai import OpenAITextCompletion, OpenAITextCompletionConfig
import uuid
from .prompt_templates.factory import prompt_factory, custom_prompt
openai_text_completion = OpenAITextCompletion()
openai_text_completion_config = OpenAITextCompletionConfig()
class AzureOpenAIError(Exception):

@@ -300,9 +301,11 @@ class AzureTextCompletion(BaseLLM):
"api_base": api_base,
},
)
return openai_text_completion.convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
return (
openai_text_completion_config.convert_to_chat_model_response_object(
response_object=TextCompletionResponse(**stringified_response),
model_response_object=model_response,
)
)
except AzureOpenAIError as e:
exception_mapping_worked = True

@@ -373,7 +376,7 @@ class AzureTextCompletion(BaseLLM):
},
)
response = await azure_client.completions.create(**data, timeout=timeout)
return openai_text_completion.convert_to_model_response_object(
return openai_text_completion_config.convert_to_chat_model_response_object(
response_object=response.model_dump(),
model_response_object=model_response,
)

@@ -55,9 +55,11 @@ def completion(
"inputs": prompt,
"prompt": prompt,
"parameters": optional_params,
"stream": True
if "stream" in optional_params and optional_params["stream"] == True
else False,
"stream": (
True
if "stream" in optional_params and optional_params["stream"] == True
else False
),
}
## LOGGING

@@ -71,9 +73,11 @@ def completion(
completion_url_fragment_1 + model + completion_url_fragment_2,
headers=headers,
data=json.dumps(data),
stream=True
if "stream" in optional_params and optional_params["stream"] == True
else False,
stream=(
True
if "stream" in optional_params and optional_params["stream"] == True
else False
),
)
if "text/event-stream" in response.headers["Content-Type"] or (
"stream" in optional_params and optional_params["stream"] == True

@@ -102,28 +106,28 @@ def completion(
and "data" in completion_response["model_output"]
and isinstance(completion_response["model_output"]["data"], list)
):
model_response["choices"][0]["message"][
"content"
] = completion_response["model_output"]["data"][0]
model_response["choices"][0]["message"]["content"] = (
completion_response["model_output"]["data"][0]
)
elif isinstance(completion_response["model_output"], str):
model_response["choices"][0]["message"][
"content"
] = completion_response["model_output"]
model_response["choices"][0]["message"]["content"] = (
completion_response["model_output"]
)
elif "completion" in completion_response and isinstance(
completion_response["completion"], str
):
model_response["choices"][0]["message"][
"content"
] = completion_response["completion"]
model_response["choices"][0]["message"]["content"] = (
completion_response["completion"]
)
elif isinstance(completion_response, list) and len(completion_response) > 0:
if "generated_text" not in completion_response:
raise BasetenError(
message=f"Unable to parse response. Original response: {response.text}",
status_code=response.status_code,
)
model_response["choices"][0]["message"][
"content"
] = completion_response[0]["generated_text"]
model_response["choices"][0]["message"]["content"] = (
completion_response[0]["generated_text"]
)
## GETTING LOGPROBS
if (
"details" in completion_response[0]

@@ -155,7 +159,8 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

@@ -653,6 +653,10 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "meta":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
else:
prompt = ""
for message in messages:

@@ -746,7 +750,7 @@ def completion(
]
# Format rest of message according to anthropic guidelines
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
model=model, messages=messages, custom_llm_provider="anthropic_xml"
)
## LOAD CONFIG
config = litellm.AmazonAnthropicClaude3Config.get_config()

@@ -1008,7 +1012,7 @@ def completion(
)
streaming_choice.delta = delta_obj
streaming_model_response.choices = [streaming_choice]
completion_stream = model_response_iterator(
completion_stream = ModelResponseIterator(
model_response=streaming_model_response
)
print_verbose(

@@ -1028,7 +1032,7 @@ def completion(
total_tokens=response_body["usage"]["input_tokens"]
+ response_body["usage"]["output_tokens"],
)
model_response.usage = _usage
setattr(model_response, "usage", _usage)
else:
outputText = response_body["completion"]
model_response["finish_reason"] = response_body["stop_reason"]

@@ -1071,8 +1075,10 @@ def completion(
status_code=response_metadata.get("HTTPStatusCode", 500),
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
if getattr(model_response.usage, "total_tokens", None) is None:
## CALCULATING USAGE - bedrock charges on time, not tokens - have some mapping of cost here.
if not hasattr(model_response, "usage"):
setattr(model_response, "usage", Usage())
if getattr(model_response.usage, "total_tokens", None) is None: # type: ignore
prompt_tokens = response_metadata.get(
"x-amzn-bedrock-input-token-count", len(encoding.encode(prompt))
)

@@ -1089,7 +1095,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
model_response["created"] = int(time.time())
model_response["model"] = model

@@ -1109,8 +1115,30 @@ def completion(
raise BedrockError(status_code=500, message=traceback.format_exc())
async def model_response_iterator(model_response):
yield model_response
class ModelResponseIterator:
def __init__(self, model_response):
self.model_response = model_response
self.is_done = False
# Sync iterator
def __iter__(self):
return self
def __next__(self):
if self.is_done:
raise StopIteration
self.is_done = True
return self.model_response
# Async iterator
def __aiter__(self):
return self
async def __anext__(self):
if self.is_done:
raise StopAsyncIteration
self.is_done = True
return self.model_response
def _embedding_func_single(

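The new ModelResponseIterator above replaces the old async generator so a single, fully parsed response can be handed to CustomStreamWrapper and consumed as either a sync or an async stream. A rough, self-contained usage sketch follows; it re-declares a class with the same shape as the one added in the diff rather than importing it, since the exact import path is not shown here.

import asyncio

class ModelResponseIterator:
    # Same shape as the class added in the diff: yields the wrapped response once, then stops,
    # for both sync and async iteration.
    def __init__(self, model_response):
        self.model_response = model_response
        self.is_done = False

    def __iter__(self):
        return self

    def __next__(self):
        if self.is_done:
            raise StopIteration
        self.is_done = True
        return self.model_response

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.is_done:
            raise StopAsyncIteration
        self.is_done = True
        return self.model_response

print(list(ModelResponseIterator("fake-chunk")))  # sync consumption: ['fake-chunk']

async def consume():
    return [chunk async for chunk in ModelResponseIterator("fake-chunk")]

print(asyncio.run(consume()))  # async consumption: ['fake-chunk']
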
@@ -167,7 +167,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

@@ -237,7 +237,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

@@ -43,6 +43,7 @@ class CohereChatConfig:
presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
seed (int, optional): A seed to assist reproducibility of the model's response.
"""
preamble: Optional[str] = None

@@ -62,6 +63,7 @@ class CohereChatConfig:
presence_penalty: Optional[int] = None
tools: Optional[list] = None
tool_results: Optional[list] = None
seed: Optional[int] = None
def __init__(
self,

@@ -82,6 +84,7 @@ class CohereChatConfig:
presence_penalty: Optional[int] = None,
tools: Optional[list] = None,
tool_results: Optional[list] = None,
seed: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():

@@ -302,5 +305,5 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

litellm/llms/custom_httpx/http_handler.py (new file, 96 lines)
@@ -0,0 +1,96 @@
import httpx, asyncio
from typing import Optional, Union, Mapping, Any

# https://www.python-httpx.org/advanced/timeouts
_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)

class AsyncHTTPHandler:
def __init__(
self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
):
# Create a client with a connection pool
self.client = httpx.AsyncClient(
timeout=timeout,
limits=httpx.Limits(
max_connections=concurrent_limit,
max_keepalive_connections=concurrent_limit,
),
)

async def close(self):
# Close the client when you're done with it
await self.client.aclose()

async def __aenter__(self):
return self.client

async def __aexit__(self):
# close the client when exiting
await self.client.aclose()

async def get(
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
):
response = await self.client.get(url, params=params, headers=headers)
return response

async def post(
self,
url: str,
data: Optional[Union[dict, str]] = None, # type: ignore
params: Optional[dict] = None,
headers: Optional[dict] = None,
stream: bool = False,
):
req = self.client.build_request(
"POST", url, data=data, params=params, headers=headers # type: ignore
)
response = await self.client.send(req, stream=stream)
return response

def __del__(self) -> None:
try:
asyncio.get_running_loop().create_task(self.close())
except Exception:
pass

class HTTPHandler:
def __init__(
self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
):
# Create a client with a connection pool
self.client = httpx.Client(
timeout=timeout,
limits=httpx.Limits(
max_connections=concurrent_limit,
max_keepalive_connections=concurrent_limit,
),
)

def close(self):
# Close the client when you're done with it
self.client.close()

def get(
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
):
response = self.client.get(url, params=params, headers=headers)
return response

def post(
self,
url: str,
data: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
):
response = self.client.post(url, data=data, params=params, headers=headers)
return response

def __del__(self) -> None:
try:
self.close()
except Exception:
pass

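The new http_handler.py module wraps httpx clients behind a shared connection pool so provider code can reuse one client per handler. A short usage sketch, assuming the module is importable as shown by the new file path; the target URLs are placeholders.

import asyncio
import httpx
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler

# Sync: one pooled client reused across calls.
sync_client = HTTPHandler(timeout=httpx.Timeout(timeout=30.0, connect=5.0))
resp = sync_client.get("https://example.com/health")  # placeholder URL
print(resp.status_code)
sync_client.close()

# Async: same idea, backed by httpx.AsyncClient under the hood.
async def main():
    client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=30.0, connect=5.0))
    resp = await client.post("https://example.com/echo", data='{"ping": "pong"}')  # placeholder URL
    print(resp.status_code)
    await client.close()

asyncio.run(main())
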
@@ -6,7 +6,8 @@ from typing import Callable, Optional
from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
import litellm
import sys, httpx
from .prompt_templates.factory import prompt_factory, custom_prompt
from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt
from packaging.version import Version
class GeminiError(Exception):

@@ -103,6 +104,13 @@ class TextStreamer:
break
def supports_system_instruction():
import google.generativeai as genai
gemini_pkg_version = Version(genai.__version__)
return gemini_pkg_version >= Version("0.5.0")
def completion(
model: str,
messages: list,

@@ -124,7 +132,7 @@ def completion(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
)
genai.configure(api_key=api_key)
system_prompt = ""
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]

@@ -135,6 +143,7 @@ def completion(
messages=messages,
)
else:
system_prompt, messages = get_system_prompt(messages=messages)
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="gemini"
)

@@ -162,11 +171,20 @@ def completion(
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": {"inference_params": inference_params}},
additional_args={
"complete_input_dict": {
"inference_params": inference_params,
"system_prompt": system_prompt,
}
},
)
## COMPLETION CALL
try:
_model = genai.GenerativeModel(f"models/{model}")
_params = {"model_name": "models/{}".format(model)}
_system_instruction = supports_system_instruction()
if _system_instruction and len(system_prompt) > 0:
_params["system_instruction"] = system_prompt
_model = genai.GenerativeModel(**_params)
if stream == True:
if acompletion == True:

@@ -213,11 +231,12 @@ def completion(
encoding=encoding,
)
else:
response = _model.generate_content(
contents=prompt,
generation_config=genai.types.GenerationConfig(**inference_params),
safety_settings=safety_settings,
)
params = {
"contents": prompt,
"generation_config": genai.types.GenerationConfig(**inference_params),
"safety_settings": safety_settings,
}
response = _model.generate_content(**params)
except Exception as e:
raise GeminiError(
message=str(e),

@@ -292,7 +311,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

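The Gemini change above gates the new system_instruction argument on the installed google-generativeai version via packaging.Version. A minimal sketch of that pattern, assuming the google-generativeai package is installed; the model name and system prompt are illustrative, and only the package, attribute, and 0.5.0 threshold come from the diff.

from packaging.version import Version

def supports_system_instruction() -> bool:
    # Mirrors the check added in the diff: system_instruction needs google-generativeai >= 0.5.0.
    import google.generativeai as genai
    return Version(genai.__version__) >= Version("0.5.0")

params = {"model_name": "models/gemini-pro"}  # illustrative model name
system_prompt = "You are a terse assistant."  # illustrative system prompt
if supports_system_instruction() and len(system_prompt) > 0:
    params["system_instruction"] = system_prompt
# _model = genai.GenerativeModel(**params)  # construct the model as the diff does
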
@@ -152,9 +152,9 @@ def completion(
else:
try:
if len(completion_response["answer"]) > 0:
model_response["choices"][0]["message"][
"content"
] = completion_response["answer"]
model_response["choices"][0]["message"]["content"] = (
completion_response["answer"]
)
except Exception as e:
raise MaritalkError(
message=response.text, status_code=response.status_code

@@ -174,7 +174,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

@@ -185,9 +185,9 @@ def completion(
else:
try:
if len(completion_response["generated_text"]) > 0:
model_response["choices"][0]["message"][
"content"
] = completion_response["generated_text"]
model_response["choices"][0]["message"]["content"] = (
completion_response["generated_text"]
)
except:
raise NLPCloudError(
message=json.dumps(completion_response),

@@ -205,7 +205,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

@@ -20,7 +20,7 @@ class OllamaError(Exception):
class OllamaConfig:
"""
Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters
Reference: https://github.com/ollama/ollama/blob/main/docs/api.md#parameters
The class `OllamaConfig` provides the configuration for the Ollama's API interface. Below are the parameters:

@@ -69,7 +69,7 @@ class OllamaConfig:
repeat_penalty: Optional[float] = None
temperature: Optional[float] = None
stop: Optional[list] = (
None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
None # stop is a list based on this - https://github.com/ollama/ollama/pull/442
)
tfs_z: Optional[float] = None
num_predict: Optional[int] = None

@@ -228,8 +228,8 @@ def get_ollama_response(
model_response["choices"][0]["message"]["content"] = response_json["response"]
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt))) # type: ignore
completion_tokens = response_json["eval_count"]
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt, disallowed_special=()))) # type: ignore
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,

@@ -330,8 +330,8 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
]
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + data["model"]
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"]))) # type: ignore
completion_tokens = response_json["eval_count"]
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"], disallowed_special=()))) # type: ignore
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,

@@ -20,7 +20,7 @@ class OllamaError(Exception):
class OllamaChatConfig:
"""
Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters
Reference: https://github.com/ollama/ollama/blob/main/docs/api.md#parameters
The class `OllamaConfig` provides the configuration for the Ollama's API interface. Below are the parameters:

@@ -69,7 +69,7 @@ class OllamaChatConfig:
repeat_penalty: Optional[float] = None
temperature: Optional[float] = None
stop: Optional[list] = (
None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
None # stop is a list based on this - https://github.com/ollama/ollama/pull/442
)
tfs_z: Optional[float] = None
num_predict: Optional[int] = None

@@ -148,7 +148,7 @@ class OllamaChatConfig:
if param == "top_p":
optional_params["top_p"] = value
if param == "frequency_penalty":
optional_params["repeat_penalty"] = param
optional_params["repeat_penalty"] = value
if param == "stop":
optional_params["stop"] = value
if param == "response_format" and value["type"] == "json_object":

@@ -184,6 +184,7 @@ class OllamaChatConfig:
# ollama implementation
def get_ollama_response(
api_base="http://localhost:11434",
api_key: Optional[str] = None,
model="llama2",
messages=None,
optional_params=None,

@@ -236,6 +237,7 @@ def get_ollama_response(
if stream == True:
response = ollama_async_streaming(
url=url,
api_key=api_key,
data=data,
model_response=model_response,
encoding=encoding,

@@ -244,6 +246,7 @@ def get_ollama_response(
else:
response = ollama_acompletion(
url=url,
api_key=api_key,
data=data,
model_response=model_response,
encoding=encoding,

@@ -252,12 +255,17 @@ def get_ollama_response(
)
return response
elif stream == True:
return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
return ollama_completion_stream(
url=url, api_key=api_key, data=data, logging_obj=logging_obj
)
response = requests.post(
url=f"{url}",
json=data,
)
_request = {
"url": f"{url}",
"json": data,
}
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
response = requests.post(**_request) # type: ignore
if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text)

@@ -307,10 +315,16 @@ def get_ollama_response(
return model_response
def ollama_completion_stream(url, data, logging_obj):
with httpx.stream(
url=url, json=data, method="POST", timeout=litellm.request_timeout
) as response:
def ollama_completion_stream(url, api_key, data, logging_obj):
_request = {
"url": f"{url}",
"json": data,
"method": "POST",
"timeout": litellm.request_timeout,
}
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
with httpx.stream(**_request) as response:
try:
if response.status_code != 200:
raise OllamaError(

@@ -329,12 +343,20 @@ def ollama_completion_stream(url, data, logging_obj):
raise e
async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
async def ollama_async_streaming(
url, api_key, data, model_response, encoding, logging_obj
):
try:
client = httpx.AsyncClient()
async with client.stream(
url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
) as response:
_request = {
"url": f"{url}",
"json": data,
"method": "POST",
"timeout": litellm.request_timeout,
}
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
async with client.stream(**_request) as response:
if response.status_code != 200:
raise OllamaError(
status_code=response.status_code, message=response.text

@@ -353,13 +375,25 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
async def ollama_acompletion(
url, data, model_response, encoding, logging_obj, function_name
url,
api_key: Optional[str],
data,
model_response,
encoding,
logging_obj,
function_name,
):
data["stream"] = False
try:
timeout = aiohttp.ClientTimeout(total=litellm.request_timeout) # 10 minutes
async with aiohttp.ClientSession(timeout=timeout) as session:
resp = await session.post(url, json=data)
_request = {
"url": f"{url}",
"json": data,
}
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
resp = await session.post(**_request)
if resp.status != 200:
text = await resp.text()

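The Ollama chat changes thread an optional api_key through every request path by building a request kwargs dict and attaching a bearer token only when a key is set. A minimal sketch of that pattern follows; the helper name and example call are assumptions, and note that the diff assigns the raw string "Bearer ..." to _request["headers"], whereas requests and httpx normally expect a header mapping like the one used here.

from typing import Optional
import requests

def post_with_optional_auth(url: str, data: dict, api_key: Optional[str] = None):
    # Build the kwargs first, then add auth only when an api_key is provided,
    # mirroring the _request-dict pattern introduced in the diff.
    request_kwargs = {"url": url, "json": data}
    if api_key is not None:
        request_kwargs["headers"] = {"Authorization": "Bearer {}".format(api_key)}
    return requests.post(**request_kwargs)

# Illustrative call (endpoint, payload, and key are placeholders):
# resp = post_with_optional_auth("http://localhost:11434/api/chat",
#                                {"model": "llama2", "messages": []},
#                                api_key="sk-...")
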
@@ -99,9 +99,9 @@ def completion(
)
else:
try:
model_response["choices"][0]["message"][
"content"
] = completion_response["choices"][0]["message"]["content"]
model_response["choices"][0]["message"]["content"] = (
completion_response["choices"][0]["message"]["content"]
)
except:
raise OobaboogaError(
message=json.dumps(completion_response),

@@ -115,7 +115,7 @@ def completion(
completion_tokens=completion_response["usage"]["completion_tokens"],
total_tokens=completion_response["usage"]["total_tokens"],
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

@ -10,6 +10,7 @@ from litellm.utils import (
|
|||
convert_to_model_response_object,
|
||||
Usage,
|
||||
TranscriptionResponse,
|
||||
TextCompletionResponse,
|
||||
)
|
||||
from typing import Callable, Optional
|
||||
import aiohttp, requests
|
||||
|
@ -200,6 +201,43 @@ class OpenAITextCompletionConfig:
and v is not None
}

def convert_to_chat_model_response_object(
self,
response_object: Optional[TextCompletionResponse] = None,
model_response_object: Optional[ModelResponse] = None,
):
try:
## RESPONSE OBJECT
if response_object is None or model_response_object is None:
raise ValueError("Error in response object format")
choice_list = []
for idx, choice in enumerate(response_object["choices"]):
message = Message(
content=choice["text"],
role="assistant",
)
choice = Choices(
finish_reason=choice["finish_reason"], index=idx, message=message
)
choice_list.append(choice)
model_response_object.choices = choice_list

if "usage" in response_object:
setattr(model_response_object, "usage", response_object["usage"])

if "id" in response_object:
model_response_object.id = response_object["id"]

if "model" in response_object:
model_response_object.model = response_object["model"]

model_response_object._hidden_params["original_response"] = (
response_object  # track original response, if users make a litellm.text_completion() request, we can return the original response
)
return model_response_object
except Exception as e:
raise e

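
convert_to_chat_model_response_object above re-shapes a /completions-style payload (choices carrying `text`) into chat-style choices carrying an assistant `message`. A dependency-free sketch of that mapping using plain dicts (litellm's real Message/Choices/ModelResponse classes are not used here):

def text_to_chat_response(text_response: dict) -> dict:
    """Map a text-completion response dict onto a chat-completion shape."""
    chat_choices = []
    for idx, choice in enumerate(text_response.get("choices", [])):
        chat_choices.append(
            {
                "index": idx,
                "finish_reason": choice.get("finish_reason"),
                # The completion text becomes the assistant message content.
                "message": {"role": "assistant", "content": choice.get("text", "")},
            }
        )
    chat_response = {
        "id": text_response.get("id"),
        "model": text_response.get("model"),
        "choices": chat_choices,
    }
    if "usage" in text_response:
        chat_response["usage"] = text_response["usage"]
    return chat_response


example = {
    "id": "cmpl-123",
    "model": "gpt-3.5-turbo-instruct",
    "choices": [{"text": "Hello!", "finish_reason": "stop"}],
    "usage": {"prompt_tokens": 2, "completion_tokens": 2, "total_tokens": 4},
}
print(text_to_chat_response(example)["choices"][0]["message"]["content"])  # Hello!
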
class OpenAIChatCompletion(BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
|
@ -785,10 +823,10 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
optional_params: dict,
|
||||
model_response: TranscriptionResponse,
|
||||
timeout: float,
|
||||
max_retries: int,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
client=None,
|
||||
max_retries=None,
|
||||
logging_obj=None,
|
||||
atranscription: bool = False,
|
||||
):
|
||||
|
@ -962,40 +1000,6 @@ class OpenAITextCompletion(BaseLLM):
|
|||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
|
||||
def convert_to_model_response_object(
|
||||
self,
|
||||
response_object: Optional[dict] = None,
|
||||
model_response_object: Optional[ModelResponse] = None,
|
||||
):
|
||||
try:
|
||||
## RESPONSE OBJECT
|
||||
if response_object is None or model_response_object is None:
|
||||
raise ValueError("Error in response object format")
|
||||
choice_list = []
|
||||
for idx, choice in enumerate(response_object["choices"]):
|
||||
message = Message(content=choice["text"], role="assistant")
|
||||
choice = Choices(
|
||||
finish_reason=choice["finish_reason"], index=idx, message=message
|
||||
)
|
||||
choice_list.append(choice)
|
||||
model_response_object.choices = choice_list
|
||||
|
||||
if "usage" in response_object:
|
||||
model_response_object.usage = response_object["usage"]
|
||||
|
||||
if "id" in response_object:
|
||||
model_response_object.id = response_object["id"]
|
||||
|
||||
if "model" in response_object:
|
||||
model_response_object.model = response_object["model"]
|
||||
|
||||
model_response_object._hidden_params["original_response"] = (
|
||||
response_object # track original response, if users make a litellm.text_completion() request, we can return the original response
|
||||
)
|
||||
return model_response_object
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model_response: ModelResponse,
|
||||
|
@ -1010,6 +1014,8 @@ class OpenAITextCompletion(BaseLLM):
|
|||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
client=None,
|
||||
organization: Optional[str] = None,
|
||||
headers: Optional[dict] = None,
|
||||
):
|
||||
super().completion()
|
||||
|
@ -1020,8 +1026,6 @@ class OpenAITextCompletion(BaseLLM):
|
|||
if model is None or messages is None:
|
||||
raise OpenAIError(status_code=422, message=f"Missing model or messages")
|
||||
|
||||
api_base = f"{api_base}/completions"
|
||||
|
||||
if (
|
||||
len(messages) > 0
|
||||
and "content" in messages[0]
|
||||
|
@ -1029,12 +1033,12 @@ class OpenAITextCompletion(BaseLLM):
|
|||
):
|
||||
prompt = messages[0]["content"]
|
||||
else:
|
||||
prompt = " ".join([message["content"] for message in messages]) # type: ignore
|
||||
prompt = [message["content"] for message in messages] # type: ignore
|
||||
|
||||
# don't send max retries to the api, if set
|
||||
optional_params.pop("max_retries", None)
|
||||
|
||||
data = {"model": model, "prompt": prompt, **optional_params}
|
||||
max_retries = data.pop("max_retries", 2)
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
|
@ -1050,38 +1054,53 @@ class OpenAITextCompletion(BaseLLM):
|
|||
return self.async_streaming(
|
||||
logging_obj=logging_obj,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
headers=headers,
|
||||
model_response=model_response,
|
||||
model=model,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
organization=organization,
|
||||
)
|
||||
else:
|
||||
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout) # type: ignore
|
||||
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client) # type: ignore
|
||||
elif optional_params.get("stream", False):
|
||||
return self.streaming(
|
||||
logging_obj=logging_obj,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
headers=headers,
|
||||
model_response=model_response,
|
||||
model=model,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries, # type: ignore
|
||||
client=client,
|
||||
organization=organization,
|
||||
)
|
||||
else:
|
||||
response = httpx.post(
|
||||
url=f"{api_base}", json=data, headers=headers, timeout=timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise OpenAIError(
|
||||
status_code=response.status_code, message=response.text
|
||||
if client is None:
|
||||
openai_client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
http_client=litellm.client_session,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries, # type: ignore
|
||||
organization=organization,
|
||||
)
|
||||
else:
|
||||
openai_client = client
|
||||
|
||||
response = openai_client.completions.create(**data) # type: ignore
|
||||
|
||||
response_json = response.model_dump()
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
original_response=response_json,
|
||||
additional_args={
|
||||
"headers": headers,
|
||||
"api_base": api_base,
|
||||
|
@ -1089,10 +1108,7 @@ class OpenAITextCompletion(BaseLLM):
|
|||
)
|
||||
|
||||
## RESPONSE OBJECT
|
||||
return self.convert_to_model_response_object(
|
||||
response_object=response.json(),
|
||||
model_response_object=model_response,
|
||||
)
|
||||
return TextCompletionResponse(**response_json)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
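
The rewritten synchronous path above drops the raw httpx.post call in favour of the OpenAI SDK client and then serialises the result with model_dump(). A minimal sketch of that call pattern (the model name and prompt are placeholders, and OPENAI_API_KEY must be set for it to run):

from openai import OpenAI

# Assumes OPENAI_API_KEY is set in the environment.
client = OpenAI(timeout=600, max_retries=2)

response = client.completions.create(
    model="gpt-3.5-turbo-instruct",  # placeholder model
    prompt="Say hello in one word.",
    max_tokens=5,
)

# The diff converts the SDK's pydantic object to a plain dict before logging and returning.
response_json = response.model_dump()
print(response_json["choices"][0]["text"])
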
@ -1107,101 +1123,112 @@ class OpenAITextCompletion(BaseLLM):
|
|||
api_key: str,
|
||||
model: str,
|
||||
timeout: float,
|
||||
max_retries=None,
|
||||
organization: Optional[str] = None,
|
||||
client=None,
|
||||
):
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
api_base,
|
||||
json=data,
|
||||
headers=headers,
|
||||
timeout=litellm.request_timeout,
|
||||
)
|
||||
response_json = response.json()
|
||||
if response.status_code != 200:
|
||||
raise OpenAIError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
try:
|
||||
if client is None:
|
||||
openai_aclient = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
additional_args={
|
||||
"headers": headers,
|
||||
"api_base": api_base,
|
||||
},
|
||||
base_url=api_base,
|
||||
http_client=litellm.aclient_session,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
)
|
||||
else:
|
||||
openai_aclient = client
|
||||
|
||||
## RESPONSE OBJECT
|
||||
return self.convert_to_model_response_object(
|
||||
response_object=response_json, model_response_object=model_response
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
response = await openai_aclient.completions.create(**data)
|
||||
response_json = response.model_dump()
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
additional_args={
|
||||
"headers": headers,
|
||||
"api_base": api_base,
|
||||
},
|
||||
)
|
||||
## RESPONSE OBJECT
|
||||
response_obj = TextCompletionResponse(**response_json)
|
||||
response_obj._hidden_params.original_response = json.dumps(response_json)
|
||||
return response_obj
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def streaming(
|
||||
self,
|
||||
logging_obj,
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
data: dict,
|
||||
headers: dict,
|
||||
model_response: ModelResponse,
|
||||
model: str,
|
||||
timeout: float,
|
||||
api_base: Optional[str] = None,
|
||||
max_retries=None,
|
||||
client=None,
|
||||
organization=None,
|
||||
):
|
||||
with httpx.stream(
|
||||
url=f"{api_base}",
|
||||
json=data,
|
||||
headers=headers,
|
||||
method="POST",
|
||||
timeout=timeout,
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
raise OpenAIError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=response.iter_lines(),
|
||||
model=model,
|
||||
custom_llm_provider="text-completion-openai",
|
||||
logging_obj=logging_obj,
|
||||
if client is None:
|
||||
openai_client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
http_client=litellm.client_session,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries, # type: ignore
|
||||
organization=organization,
|
||||
)
|
||||
for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
else:
|
||||
openai_client = client
|
||||
response = openai_client.completions.create(**data)
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=response,
|
||||
model=model,
|
||||
custom_llm_provider="text-completion-openai",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
for chunk in streamwrapper:
|
||||
yield chunk
|
||||
|
||||
async def async_streaming(
|
||||
self,
|
||||
logging_obj,
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
data: dict,
|
||||
headers: dict,
|
||||
model_response: ModelResponse,
|
||||
model: str,
|
||||
timeout: float,
|
||||
api_base: Optional[str] = None,
|
||||
client=None,
|
||||
max_retries=None,
|
||||
organization=None,
|
||||
):
|
||||
client = httpx.AsyncClient()
|
||||
async with client.stream(
|
||||
url=f"{api_base}",
|
||||
json=data,
|
||||
headers=headers,
|
||||
method="POST",
|
||||
timeout=timeout,
|
||||
) as response:
|
||||
try:
|
||||
if response.status_code != 200:
|
||||
raise OpenAIError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
if client is None:
|
||||
openai_client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
http_client=litellm.aclient_session,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
)
|
||||
else:
|
||||
openai_client = client
|
||||
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=response.aiter_lines(),
|
||||
model=model,
|
||||
custom_llm_provider="text-completion-openai",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
async for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
except Exception as e:
|
||||
raise e
|
||||
response = await openai_client.completions.create(**data)
|
||||
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=response,
|
||||
model=model,
|
||||
custom_llm_provider="text-completion-openai",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
async for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
|
|
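
The streaming and async_streaming rewrites above now take their chunks straight from the SDK (completions.create with stream=True) and hand the iterator to CustomStreamWrapper instead of parsing raw SSE lines. A hedged sketch of consuming such a stream directly, without litellm's wrapper (model name is a placeholder; OPENAI_API_KEY must be set):

import asyncio

from openai import AsyncOpenAI


async def stream_text_completion(prompt: str) -> str:
    client = AsyncOpenAI()
    stream = await client.completions.create(
        model="gpt-3.5-turbo-instruct",
        prompt=prompt,
        max_tokens=32,
        stream=True,
    )
    collected = []
    async for chunk in stream:  # each chunk carries a partial `text` delta
        collected.append(chunk.choices[0].text)
    return "".join(collected)


# asyncio.run(stream_text_completion("Write a haiku about diffs."))
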
|
@ -191,7 +191,7 @@ def completion(
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
|
||||
|
|
|
@ -41,9 +41,9 @@ class PetalsConfig:
|
|||
"""
|
||||
|
||||
max_length: Optional[int] = None
|
||||
max_new_tokens: Optional[
|
||||
int
|
||||
] = litellm.max_tokens # petals requires max tokens to be set
|
||||
max_new_tokens: Optional[int] = (
|
||||
litellm.max_tokens
|
||||
) # petals requires max tokens to be set
|
||||
do_sample: Optional[bool] = None
|
||||
temperature: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
|
@ -203,7 +203,7 @@ def completion(
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
from enum import Enum
|
||||
import requests, traceback
|
||||
import json, re, xml.etree.ElementTree as ET
|
||||
from jinja2 import Template, exceptions, Environment, meta
|
||||
from jinja2 import Template, exceptions, meta, BaseLoader
|
||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||
from typing import Optional, Any
|
||||
import imghdr, base64
|
||||
from typing import List
|
||||
import litellm
|
||||
|
||||
|
@ -62,7 +62,7 @@ def llama_2_chat_pt(messages):
|
|||
|
||||
def ollama_pt(
|
||||
model, messages
|
||||
): # https://github.com/jmorganca/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
|
||||
): # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
|
||||
if "instruct" in model:
|
||||
prompt = custom_prompt(
|
||||
role_dict={
|
||||
|
@ -145,6 +145,12 @@ def mistral_api_pt(messages):
|
|||
elif isinstance(m["content"], str):
|
||||
texts = m["content"]
|
||||
new_m = {"role": m["role"], "content": texts}
|
||||
|
||||
if new_m["role"] == "tool" and m.get("name"):
|
||||
new_m["name"] = m["name"]
|
||||
if m.get("tool_calls"):
|
||||
new_m["tool_calls"] = m["tool_calls"]
|
||||
|
||||
new_messages.append(new_m)
|
||||
return new_messages
|
||||
|
||||
|
@ -218,7 +224,36 @@ def phind_codellama_pt(messages):
|
|||
return prompt
|
||||
|
||||
|
||||
known_tokenizer_config = {
|
||||
"mistralai/Mistral-7B-Instruct-v0.1": {
|
||||
"tokenizer": {
|
||||
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"status": "success",
|
||||
},
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct": {
|
||||
"tokenizer": {
|
||||
"chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
|
||||
"bos_token": "<|begin_of_text|>",
|
||||
"eos_token": "",
|
||||
},
|
||||
"status": "success",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] = None):
|
||||
# Define Jinja2 environment
|
||||
env = ImmutableSandboxedEnvironment()
|
||||
|
||||
def raise_exception(message):
|
||||
raise Exception(f"Error message - {message}")
|
||||
|
||||
# Create a template object from the template text
|
||||
env.globals["raise_exception"] = raise_exception
|
||||
|
||||
## get the tokenizer config from huggingface
|
||||
bos_token = ""
|
||||
eos_token = ""
|
||||
|
@ -237,26 +272,23 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
|
|||
else:
|
||||
return {"status": "failure"}
|
||||
|
||||
tokenizer_config = _get_tokenizer_config(model)
|
||||
if model in known_tokenizer_config:
|
||||
tokenizer_config = known_tokenizer_config[model]
|
||||
else:
|
||||
tokenizer_config = _get_tokenizer_config(model)
|
||||
if (
|
||||
tokenizer_config["status"] == "failure"
|
||||
or "chat_template" not in tokenizer_config["tokenizer"]
|
||||
):
|
||||
raise Exception("No chat template found")
|
||||
## read the bos token, eos token and chat template from the json
|
||||
tokenizer_config = tokenizer_config["tokenizer"]
|
||||
bos_token = tokenizer_config["bos_token"]
|
||||
eos_token = tokenizer_config["eos_token"]
|
||||
chat_template = tokenizer_config["chat_template"]
|
||||
tokenizer_config = tokenizer_config["tokenizer"] # type: ignore
|
||||
|
||||
def raise_exception(message):
|
||||
raise Exception(f"Error message - {message}")
|
||||
|
||||
# Create a template object from the template text
|
||||
env = Environment()
|
||||
env.globals["raise_exception"] = raise_exception
|
||||
bos_token = tokenizer_config["bos_token"] # type: ignore
|
||||
eos_token = tokenizer_config["eos_token"] # type: ignore
|
||||
chat_template = tokenizer_config["chat_template"] # type: ignore
|
||||
try:
|
||||
template = env.from_string(chat_template)
|
||||
template = env.from_string(chat_template) # type: ignore
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
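
The hf_chat_template changes above swap the plain Jinja2 Environment for ImmutableSandboxedEnvironment and add a known_tokenizer_config fallback so common models render without a Hugging Face round-trip. A small sketch of rendering a chat template in that sandbox (the template below is a simplified example, not one of the entries in known_tokenizer_config):

from jinja2.sandbox import ImmutableSandboxedEnvironment


def raise_exception(message):
    # Chat templates call this helper when roles are malformed.
    raise Exception(f"Error message - {message}")


env = ImmutableSandboxedEnvironment()
env.globals["raise_exception"] = raise_exception

# Simplified template: wrap user turns in [INST] tags, echo assistant turns.
chat_template = (
    "{{ bos_token }}{% for message in messages %}"
    "{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]"
    "{% elif message['role'] == 'assistant' %}{{ message['content'] }}{{ eos_token }}"
    "{% else %}{{ raise_exception('unsupported role') }}{% endif %}{% endfor %}"
)

template = env.from_string(chat_template)
rendered = template.render(
    bos_token="<s>",
    eos_token="</s>",
    messages=[
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ],
)
print(rendered)  # <s>[INST] Hi [/INST]Hello!</s>
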
@ -463,10 +495,11 @@ def construct_tool_use_system_prompt(
|
|||
): # from https://github.com/anthropics/anthropic-cookbook/blob/main/function_calling/function_calling.ipynb
|
||||
tool_str_list = []
|
||||
for tool in tools:
|
||||
tool_function = get_attribute_or_key(tool, "function")
|
||||
tool_str = construct_format_tool_for_claude_prompt(
|
||||
tool["function"]["name"],
|
||||
tool["function"].get("description", ""),
|
||||
tool["function"].get("parameters", {}),
|
||||
get_attribute_or_key(tool_function, "name"),
|
||||
get_attribute_or_key(tool_function, "description", ""),
|
||||
get_attribute_or_key(tool_function, "parameters", {}),
|
||||
)
|
||||
tool_str_list.append(tool_str)
|
||||
tool_use_system_prompt = (
|
||||
|
@ -556,7 +589,9 @@ def convert_to_anthropic_image_obj(openai_image_url: str):
|
|||
)
|
||||
|
||||
|
||||
def convert_to_anthropic_tool_result(message: dict) -> str:
|
||||
# The following XML functions will be deprecated once JSON schema support is available on Bedrock and Vertex
|
||||
# ------------------------------------------------------------------------------
|
||||
def convert_to_anthropic_tool_result_xml(message: dict) -> str:
|
||||
"""
|
||||
OpenAI message with a tool result looks like:
|
||||
{
|
||||
|
@ -588,7 +623,8 @@ def convert_to_anthropic_tool_result(message: dict) -> str:
|
|||
</function_results>
|
||||
"""
|
||||
name = message.get("name")
|
||||
content = message.get("content")
|
||||
content = message.get("content", "")
|
||||
content = content.replace("<", "<").replace(">", ">").replace("&", "&")
|
||||
|
||||
# We can't determine from openai message format whether it's a successful or
|
||||
# error call result so default to the successful result template
|
||||
|
@ -606,16 +642,18 @@ def convert_to_anthropic_tool_result(message: dict) -> str:
|
|||
return anthropic_tool_result
|
||||
|
||||
|
||||
def convert_to_anthropic_tool_invoke(tool_calls: list) -> str:
|
||||
def convert_to_anthropic_tool_invoke_xml(tool_calls: list) -> str:
|
||||
invokes = ""
|
||||
for tool in tool_calls:
|
||||
if tool["type"] != "function":
|
||||
if get_attribute_or_key(tool, "type") != "function":
|
||||
continue
|
||||
|
||||
tool_name = tool["function"]["name"]
|
||||
tool_function = get_attribute_or_key(tool, "function")
|
||||
tool_name = get_attribute_or_key(tool_function, "name")
|
||||
tool_arguments = get_attribute_or_key(tool_function, "arguments")
|
||||
parameters = "".join(
|
||||
f"<{param}>{val}</{param}>\n"
|
||||
for param, val in json.loads(tool["function"]["arguments"]).items()
|
||||
for param, val in json.loads(tool_arguments).items()
|
||||
)
|
||||
invokes += (
|
||||
"<invoke>\n"
|
||||
|
@ -631,7 +669,7 @@ def convert_to_anthropic_tool_invoke(tool_calls: list) -> str:
|
|||
return anthropic_tool_invoke
|
||||
|
||||
|
||||
def anthropic_messages_pt(messages: list):
|
||||
def anthropic_messages_pt_xml(messages: list):
|
||||
"""
|
||||
format messages for anthropic
|
||||
1. Anthropic supports roles like "user" and "assistant", (here litellm translates system-> assistant)
|
||||
|
@ -669,7 +707,7 @@ def anthropic_messages_pt(messages: list):
|
|||
{
|
||||
"type": "text",
|
||||
"text": (
|
||||
convert_to_anthropic_tool_result(messages[msg_i])
|
||||
convert_to_anthropic_tool_result_xml(messages[msg_i])
|
||||
if messages[msg_i]["role"] == "tool"
|
||||
else messages[msg_i]["content"]
|
||||
),
|
||||
|
@ -696,7 +734,7 @@ def anthropic_messages_pt(messages: list):
|
|||
if messages[msg_i].get(
|
||||
"tool_calls", []
|
||||
): # support assistant tool invoke convertion
|
||||
assistant_text += convert_to_anthropic_tool_invoke(
|
||||
assistant_text += convert_to_anthropic_tool_invoke_xml( # type: ignore
|
||||
messages[msg_i]["tool_calls"]
|
||||
)
|
||||
|
||||
|
@ -706,7 +744,192 @@ def anthropic_messages_pt(messages: list):
|
|||
if assistant_content:
|
||||
new_messages.append({"role": "assistant", "content": assistant_content})
|
||||
|
||||
if new_messages[0]["role"] != "user":
|
||||
if not new_messages or new_messages[0]["role"] != "user":
|
||||
if litellm.modify_params:
|
||||
new_messages.insert(
|
||||
0, {"role": "user", "content": [{"type": "text", "text": "."}]}
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
"Invalid first message. Should always start with 'role'='user' for Anthropic. System prompt is sent separately for Anthropic. set 'litellm.modify_params = True' or 'litellm_settings:modify_params = True' on proxy, to insert a placeholder user message - '.' as the first message, "
|
||||
)
|
||||
|
||||
if new_messages[-1]["role"] == "assistant":
|
||||
for content in new_messages[-1]["content"]:
|
||||
if isinstance(content, dict) and content["type"] == "text":
|
||||
content["text"] = content[
|
||||
"text"
|
||||
].rstrip() # no trailing whitespace for final assistant message
|
||||
|
||||
return new_messages
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------


def convert_to_anthropic_tool_result(message: dict) -> dict:
"""
OpenAI message with a tool result looks like:
{
"tool_call_id": "tool_1",
"role": "tool",
"name": "get_current_weather",
"content": "function result goes here",
},
"""

"""
Anthropic tool_results look like:
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_01A09q90qw90lq917835lq9",
"content": "ConnectionError: the weather service API is not available (HTTP 500)",
# "is_error": true
}
]
}
"""
tool_call_id = message.get("tool_call_id")
content = message.get("content")

# We can't determine from openai message format whether it's a successful or
# error call result so default to the successful result template
anthropic_tool_result = {
"type": "tool_result",
"tool_use_id": tool_call_id,
"content": content,
}

return anthropic_tool_result

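
The JSON-native convert_to_anthropic_tool_result above replaces the XML variant: the OpenAI tool message's tool_call_id and content map directly onto Anthropic's tool_result block inside a user turn. A self-contained sketch of that round trip (the helper name and sample values are illustrative only):

def openai_tool_message_to_anthropic_user_turn(message: dict) -> dict:
    """Wrap an OpenAI-style tool result in an Anthropic-style user message."""
    tool_result = {
        "type": "tool_result",
        "tool_use_id": message.get("tool_call_id"),
        "content": message.get("content"),
    }
    return {"role": "user", "content": [tool_result]}


openai_tool_message = {
    "tool_call_id": "call_abc123",
    "role": "tool",
    "name": "get_current_weather",
    "content": "72F and sunny",
}
print(openai_tool_message_to_anthropic_user_turn(openai_tool_message))
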
def convert_to_anthropic_tool_invoke(tool_calls: list) -> list:
|
||||
"""
|
||||
OpenAI tool invokes:
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abc123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": "{\n\"location\": \"Boston, MA\"\n}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"""
|
||||
|
||||
"""
|
||||
Anthropic tool invokes:
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "<thinking>To answer this question, I will: 1. Use the get_weather tool to get the current weather in San Francisco. 2. Use the get_time tool to get the current time in the America/Los_Angeles timezone, which covers San Francisco, CA.</thinking>"
|
||||
},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_01A09q90qw90lq917835lq9",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "San Francisco, CA"}
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
anthropic_tool_invoke = [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": get_attribute_or_key(tool, "id"),
|
||||
"name": get_attribute_or_key(
|
||||
get_attribute_or_key(tool, "function"), "name"
|
||||
),
|
||||
"input": json.loads(
|
||||
get_attribute_or_key(
|
||||
get_attribute_or_key(tool, "function"), "arguments"
|
||||
)
|
||||
),
|
||||
}
|
||||
for tool in tool_calls
|
||||
if get_attribute_or_key(tool, "type") == "function"
|
||||
]
|
||||
|
||||
return anthropic_tool_invoke
|
||||
|
||||
|
||||
def anthropic_messages_pt(messages: list):
|
||||
"""
|
||||
format messages for anthropic
|
||||
1. Anthropic supports roles like "user" and "assistant", (here litellm translates system-> assistant)
|
||||
2. The first message always needs to be of role "user"
|
||||
3. Each message must alternate between "user" and "assistant" (this is not addressed as now by litellm)
|
||||
4. final assistant content cannot end with trailing whitespace (anthropic raises an error otherwise)
|
||||
5. System messages are a separate param to the Messages API
|
||||
6. Ensure we only accept role, content. (message.name is not supported)
|
||||
"""
|
||||
# add role=tool support to allow function call result/error submission
|
||||
user_message_types = {"user", "tool"}
|
||||
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, merge them.
|
||||
new_messages = []
|
||||
msg_i = 0
|
||||
while msg_i < len(messages):
|
||||
user_content = []
|
||||
## MERGE CONSECUTIVE USER CONTENT ##
|
||||
while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
|
||||
if isinstance(messages[msg_i]["content"], list):
|
||||
for m in messages[msg_i]["content"]:
|
||||
if m.get("type", "") == "image_url":
|
||||
user_content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": convert_to_anthropic_image_obj(
|
||||
m["image_url"]["url"]
|
||||
),
|
||||
}
|
||||
)
|
||||
elif m.get("type", "") == "text":
|
||||
user_content.append({"type": "text", "text": m["text"]})
|
||||
elif messages[msg_i]["role"] == "tool":
|
||||
# OpenAI's tool message content will always be a string
|
||||
user_content.append(convert_to_anthropic_tool_result(messages[msg_i]))
|
||||
else:
|
||||
user_content.append(
|
||||
{"type": "text", "text": messages[msg_i]["content"]}
|
||||
)
|
||||
|
||||
msg_i += 1
|
||||
|
||||
if user_content:
|
||||
new_messages.append({"role": "user", "content": user_content})
|
||||
|
||||
assistant_content = []
|
||||
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
|
||||
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
|
||||
assistant_text = (
|
||||
messages[msg_i].get("content") or ""
|
||||
) # either string or none
|
||||
if assistant_text:
|
||||
assistant_content.append({"type": "text", "text": assistant_text})
|
||||
|
||||
if messages[msg_i].get(
|
||||
"tool_calls", []
|
||||
): # support assistant tool invoke convertion
|
||||
assistant_content.extend(
|
||||
convert_to_anthropic_tool_invoke(messages[msg_i]["tool_calls"])
|
||||
)
|
||||
|
||||
msg_i += 1
|
||||
|
||||
if assistant_content:
|
||||
new_messages.append({"role": "assistant", "content": assistant_content})
|
||||
|
||||
if not new_messages or new_messages[0]["role"] != "user":
|
||||
if litellm.modify_params:
|
||||
new_messages.insert(
|
||||
0, {"role": "user", "content": [{"type": "text", "text": "."}]}
|
||||
|
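
anthropic_messages_pt above enforces Anthropic's Messages API constraints: turns must alternate, the first turn must be a user turn, and tool results count as user content. A compact sketch of just the merge step for plain text messages (the image and tool_use handling from the function above is omitted):

def merge_consecutive_turns(messages: list) -> list:
    """Collapse consecutive user/tool and assistant turns into single Anthropic turns."""
    user_roles = {"user", "tool"}
    merged, i = [], 0
    while i < len(messages):
        user_content = []
        while i < len(messages) and messages[i]["role"] in user_roles:
            user_content.append({"type": "text", "text": messages[i]["content"]})
            i += 1
        if user_content:
            merged.append({"role": "user", "content": user_content})

        assistant_content = []
        while i < len(messages) and messages[i]["role"] == "assistant":
            assistant_content.append({"type": "text", "text": messages[i]["content"]})
            i += 1
        if assistant_content:
            merged.append({"role": "assistant", "content": assistant_content})
    return merged


demo = [
    {"role": "user", "content": "What's the weather?"},
    {"role": "tool", "content": "72F"},
    {"role": "assistant", "content": "It's 72F."},
]
print(merge_consecutive_turns(demo))  # one merged user turn, then one assistant turn
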
@ -784,7 +1007,20 @@ def parse_xml_params(xml_content, json_schema: Optional[dict] = None):
return params


###
### GEMINI HELPER FUNCTIONS ###


def get_system_prompt(messages):
system_prompt_indices = []
system_prompt = ""
for idx, message in enumerate(messages):
if message["role"] == "system":
system_prompt += message["content"]
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
return system_prompt, messages

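
get_system_prompt pulls every system message out of the list (Gemini takes the system prompt separately) and returns the remaining messages. A quick self-contained usage sketch that restates the same logic so it runs on its own:

def get_system_prompt(messages):
    # Same logic as the helper above: concatenate system contents, drop those entries in place.
    system_prompt_indices = []
    system_prompt = ""
    for idx, message in enumerate(messages):
        if message["role"] == "system":
            system_prompt += message["content"]
            system_prompt_indices.append(idx)
    for idx in reversed(system_prompt_indices):
        messages.pop(idx)
    return system_prompt, messages


msgs = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Explain diffs."},
]
system, remaining = get_system_prompt(msgs)
print(system)          # You are terse.
print(len(remaining))  # 1 -- only the user message is left
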
def convert_openai_message_to_cohere_tool_result(message):
|
||||
|
@ -842,7 +1078,8 @@ def cohere_message_pt(messages: list):
|
|||
tool_result = convert_openai_message_to_cohere_tool_result(message)
|
||||
tool_results.append(tool_result)
|
||||
else:
|
||||
prompt += message["content"]
|
||||
prompt += message["content"] + "\n\n"
|
||||
prompt = prompt.rstrip()
|
||||
return prompt, tool_results
|
||||
|
||||
|
||||
|
@ -916,12 +1153,6 @@ def _gemini_vision_convert_messages(messages: list):
|
|||
Returns:
|
||||
tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images).
|
||||
"""
|
||||
try:
|
||||
from PIL import Image
|
||||
except:
|
||||
raise Exception(
|
||||
"gemini image conversion failed please run `pip install Pillow`"
|
||||
)
|
||||
|
||||
try:
|
||||
# given messages for gpt-4 vision, convert them for gemini
|
||||
|
@ -948,6 +1179,12 @@ def _gemini_vision_convert_messages(messages: list):
|
|||
image = _load_image_from_url(img)
|
||||
processed_images.append(image)
|
||||
else:
|
||||
try:
|
||||
from PIL import Image
|
||||
except:
|
||||
raise Exception(
|
||||
"gemini image conversion failed please run `pip install Pillow`"
|
||||
)
|
||||
# Case 2: Image filepath (e.g. temp.jpeg) given
|
||||
image = Image.open(img)
|
||||
processed_images.append(image)
|
||||
|
@ -1087,13 +1324,19 @@ def prompt_factory(
|
|||
if model == "claude-instant-1" or model == "claude-2":
|
||||
return anthropic_pt(messages=messages)
|
||||
return anthropic_messages_pt(messages=messages)
|
||||
elif custom_llm_provider == "anthropic_xml":
|
||||
return anthropic_messages_pt_xml(messages=messages)
|
||||
elif custom_llm_provider == "together_ai":
|
||||
prompt_format, chat_template = get_model_info(token=api_key, model=model)
|
||||
return format_prompt_togetherai(
|
||||
messages=messages, prompt_format=prompt_format, chat_template=chat_template
|
||||
)
|
||||
elif custom_llm_provider == "gemini":
|
||||
if model == "gemini-pro-vision":
|
||||
if (
|
||||
model == "gemini-pro-vision"
|
||||
or litellm.supports_vision(model=model)
|
||||
or litellm.supports_vision(model=custom_llm_provider + "/" + model)
|
||||
):
|
||||
return _gemini_vision_convert_messages(messages=messages)
|
||||
else:
|
||||
return gemini_text_image_pt(messages=messages)
|
||||
|
@ -1109,6 +1352,13 @@ def prompt_factory(
|
|||
return anthropic_pt(messages=messages)
|
||||
elif "mistral." in model:
|
||||
return mistral_instruct_pt(messages=messages)
|
||||
elif "llama2" in model and "chat" in model:
|
||||
return llama_2_chat_pt(messages=messages)
|
||||
elif "llama3" in model and "instruct" in model:
|
||||
return hf_chat_template(
|
||||
model="meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
messages=messages,
|
||||
)
|
||||
elif custom_llm_provider == "perplexity":
|
||||
for message in messages:
|
||||
message.pop("name", None)
|
||||
|
@ -1118,6 +1368,13 @@ def prompt_factory(
|
|||
try:
|
||||
if "meta-llama/llama-2" in model and "chat" in model:
|
||||
return llama_2_chat_pt(messages=messages)
|
||||
elif (
|
||||
"meta-llama/llama-3" in model or "meta-llama-3" in model
|
||||
) and "instruct" in model:
|
||||
return hf_chat_template(
|
||||
model="meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
messages=messages,
|
||||
)
|
||||
elif (
|
||||
"tiiuae/falcon" in model
|
||||
): # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
|
||||
|
@ -1158,3 +1415,9 @@ def prompt_factory(
|
|||
return default_pt(
|
||||
messages=messages
|
||||
) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)


def get_attribute_or_key(tool_or_function, attribute, default=None):
if hasattr(tool_or_function, attribute):
return getattr(tool_or_function, attribute)
return tool_or_function.get(attribute, default)
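
get_attribute_or_key lets the tool-handling code accept either SDK objects (attribute access) or plain dicts (key access). A short sketch showing both cases with a stand-in object:

from types import SimpleNamespace


def get_attribute_or_key(tool_or_function, attribute, default=None):
    # Same shape as the helper above: prefer attribute access, fall back to dict lookup.
    if hasattr(tool_or_function, attribute):
        return getattr(tool_or_function, attribute)
    return tool_or_function.get(attribute, default)


as_dict = {"name": "get_current_weather", "parameters": {"type": "object"}}
as_object = SimpleNamespace(name="get_current_weather", parameters={"type": "object"})

assert get_attribute_or_key(as_dict, "name") == "get_current_weather"
assert get_attribute_or_key(as_object, "name") == "get_current_weather"
assert get_attribute_or_key(as_dict, "description", "") == ""
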
|
@ -112,10 +112,16 @@ def start_prediction(
|
|||
}
|
||||
|
||||
initial_prediction_data = {
|
||||
"version": version_id,
|
||||
"input": input_data,
|
||||
}
|
||||
|
||||
if ":" in version_id and len(version_id) > 64:
|
||||
model_parts = version_id.split(":")
|
||||
if (
|
||||
len(model_parts) > 1 and len(model_parts[1]) == 64
|
||||
): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
|
||||
initial_prediction_data["version"] = model_parts[1]
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input_data["prompt"],
|
||||
|
@ -307,9 +313,7 @@ def completion(
|
|||
result, logs = handle_prediction_response(
|
||||
prediction_url, api_key, print_verbose
|
||||
)
|
||||
model_response["ended"] = (
|
||||
time.time()
|
||||
) # for pricing this must remain right after calling api
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
|
@ -332,9 +336,12 @@ def completion(
|
|||
model_response["choices"][0]["message"]["content"] = result
|
||||
|
||||
# Calculate usage
|
||||
prompt_tokens = len(encoding.encode(prompt))
|
||||
prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
|
||||
encoding.encode(
|
||||
model_response["choices"][0]["message"].get("content", ""),
|
||||
disallowed_special=(),
|
||||
)
|
||||
)
|
||||
model_response["model"] = "replicate/" + model
|
||||
usage = Usage(
|
||||
|
@ -342,7 +349,7 @@ def completion(
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
|
||||
|
|
|
@ -365,7 +365,10 @@ def completion(
|
|||
## RESPONSE OBJECT
|
||||
completion_response = json.loads(response)
|
||||
try:
|
||||
completion_response_choices = completion_response[0]
|
||||
if isinstance(completion_response, list):
|
||||
completion_response_choices = completion_response[0]
|
||||
else:
|
||||
completion_response_choices = completion_response
|
||||
completion_output = ""
|
||||
if "generation" in completion_response_choices:
|
||||
completion_output += completion_response_choices["generation"]
|
||||
|
@ -396,7 +399,7 @@ def completion(
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
|
||||
|
@ -580,7 +583,10 @@ async def async_completion(
|
|||
## RESPONSE OBJECT
|
||||
completion_response = json.loads(response)
|
||||
try:
|
||||
completion_response_choices = completion_response[0]
|
||||
if isinstance(completion_response, list):
|
||||
completion_response_choices = completion_response[0]
|
||||
else:
|
||||
completion_response_choices = completion_response
|
||||
completion_output = ""
|
||||
if "generation" in completion_response_choices:
|
||||
completion_output += completion_response_choices["generation"]
|
||||
|
@ -611,7 +617,7 @@ async def async_completion(
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
Deprecated. We now do together ai calls via the openai client.
|
||||
Reference: https://docs.together.ai/docs/openai-api-compatibility
|
||||
"""
|
||||
|
||||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
|
@ -225,7 +226,7 @@ def completion(
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
|
||||
|
|
litellm/llms/tokenizers/ec7223a39ce59f226a68acc30dc1af2788490e15 (new file, 50280 lines)
File diff suppressed because it is too large
|
@ -3,10 +3,10 @@ import json
|
|||
from enum import Enum
|
||||
import requests
|
||||
import time
|
||||
from typing import Callable, Optional, Union
|
||||
from typing import Callable, Optional, Union, List
|
||||
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
|
||||
import litellm, uuid
|
||||
import httpx
|
||||
import httpx, inspect
|
||||
|
||||
|
||||
class VertexAIError(Exception):
|
||||
|
@ -22,9 +22,39 @@ class VertexAIError(Exception):
|
|||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class ExtendedGenerationConfig(dict):
|
||||
"""Extended parameters for the generation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
candidate_count: Optional[int] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
stop_sequences: Optional[List[str]] = None,
|
||||
response_mime_type: Optional[str] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
):
|
||||
super().__init__(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
candidate_count=candidate_count,
|
||||
max_output_tokens=max_output_tokens,
|
||||
stop_sequences=stop_sequences,
|
||||
response_mime_type=response_mime_type,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
)
|
||||
|
||||
|
||||
class VertexAIConfig:
|
||||
"""
|
||||
Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
|
||||
Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
||||
|
||||
The class `VertexAIConfig` provides configuration for the VertexAI's API interface. Below are the parameters:
|
||||
|
||||
|
@ -36,6 +66,16 @@ class VertexAIConfig:
|
|||
|
||||
- `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
|
||||
|
||||
- `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'.
|
||||
|
||||
- `candidate_count` (int): Number of generated responses to return.
|
||||
|
||||
- `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
|
||||
|
||||
- `frequency_penalty` (float): This parameter is used to penalize the model from repeating the same output. The default value is 0.0.
|
||||
|
||||
- `presence_penalty` (float): This parameter is used to penalize the model from generating the same output as the input. The default value is 0.0.
|
||||
|
||||
Note: Please make sure to modify the default parameters as required for your use case.
|
||||
"""
|
||||
|
||||
|
@ -43,6 +83,11 @@ class VertexAIConfig:
|
|||
max_output_tokens: Optional[int] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
response_mime_type: Optional[str] = None
|
||||
candidate_count: Optional[int] = None
|
||||
stop_sequences: Optional[list] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -50,6 +95,11 @@ class VertexAIConfig:
|
|||
max_output_tokens: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
response_mime_type: Optional[str] = None,
|
||||
candidate_count: Optional[int] = None,
|
||||
stop_sequences: Optional[list] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
|
@ -74,6 +124,66 @@ class VertexAIConfig:
|
|||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"max_tokens",
|
||||
"stream",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"response_format",
|
||||
"n",
|
||||
"stop",
|
||||
]
|
||||
|
||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
if (
|
||||
param == "stream" and value == True
|
||||
): # sending stream = False, can cause it to get passed unchecked and raise issues
|
||||
optional_params["stream"] = value
|
||||
if param == "n":
|
||||
optional_params["candidate_count"] = value
|
||||
if param == "stop":
|
||||
if isinstance(value, str):
|
||||
optional_params["stop_sequences"] = [value]
|
||||
elif isinstance(value, list):
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "max_tokens":
|
||||
optional_params["max_output_tokens"] = value
|
||||
if param == "response_format" and value["type"] == "json_object":
|
||||
optional_params["response_mime_type"] = "application/json"
|
||||
if param == "frequency_penalty":
|
||||
optional_params["frequency_penalty"] = value
|
||||
if param == "presence_penalty":
|
||||
optional_params["presence_penalty"] = value
|
||||
if param == "tools" and isinstance(value, list):
|
||||
from vertexai.preview import generative_models
|
||||
|
||||
gtool_func_declarations = []
|
||||
for tool in value:
|
||||
gtool_func_declaration = generative_models.FunctionDeclaration(
|
||||
name=tool["function"]["name"],
|
||||
description=tool["function"].get("description", ""),
|
||||
parameters=tool["function"].get("parameters", {}),
|
||||
)
|
||||
gtool_func_declarations.append(gtool_func_declaration)
|
||||
optional_params["tools"] = [
|
||||
generative_models.Tool(
|
||||
function_declarations=gtool_func_declarations
|
||||
)
|
||||
]
|
||||
if param == "tool_choice" and (
|
||||
isinstance(value, str) or isinstance(value, dict)
|
||||
):
|
||||
pass
|
||||
return optional_params
|
||||
|
||||
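
map_openai_params above translates OpenAI-style arguments into Vertex AI generation settings (max_tokens -> max_output_tokens, stop -> stop_sequences, n -> candidate_count, response_format json -> response_mime_type). A dependency-free sketch of the same mapping for the scalar parameters (the tools/FunctionDeclaration branch is left out because it needs the vertexai SDK):

def map_openai_to_vertex(non_default_params: dict) -> dict:
    """Translate a subset of OpenAI params into Vertex AI generation-config keys."""
    mapped = {}
    for param, value in non_default_params.items():
        if param in ("temperature", "top_p", "frequency_penalty", "presence_penalty"):
            mapped[param] = value
        elif param == "max_tokens":
            mapped["max_output_tokens"] = value
        elif param == "n":
            mapped["candidate_count"] = value
        elif param == "stop":
            mapped["stop_sequences"] = [value] if isinstance(value, str) else value
        elif param == "response_format" and value.get("type") == "json_object":
            mapped["response_mime_type"] = "application/json"
        elif param == "stream" and value is True:
            mapped["stream"] = value
    return mapped


print(map_openai_to_vertex({"max_tokens": 256, "stop": "###", "n": 2}))
# {'max_output_tokens': 256, 'stop_sequences': ['###'], 'candidate_count': 2}
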
|
||||
import asyncio
|
||||
|
||||
|
@ -117,8 +227,7 @@ def _get_image_bytes_from_url(image_url: str) -> bytes:
|
|||
image_bytes = response.content
|
||||
return image_bytes
|
||||
except requests.exceptions.RequestException as e:
|
||||
# Handle any request exceptions (e.g., connection error, timeout)
|
||||
return b"" # Return an empty bytes object or handle the error as needed
|
||||
raise Exception(f"An exception occurs with this image - {str(e)}")
|
||||
|
||||
|
||||
def _load_image_from_url(image_url: str):
|
||||
|
@ -139,7 +248,8 @@ def _load_image_from_url(image_url: str):
|
|||
)
|
||||
|
||||
image_bytes = _get_image_bytes_from_url(image_url)
|
||||
return Image.from_bytes(image_bytes)
|
||||
|
||||
return Image.from_bytes(data=image_bytes)
|
||||
|
||||
|
||||
def _gemini_vision_convert_messages(messages: list):
|
||||
|
@ -257,6 +367,7 @@ def completion(
|
|||
logging_obj,
|
||||
vertex_project=None,
|
||||
vertex_location=None,
|
||||
vertex_credentials=None,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
|
@ -299,7 +410,17 @@ def completion(
print_verbose(
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
)
creds, _ = google.auth.default(quota_project_id=vertex_project)
if vertex_credentials is not None and isinstance(vertex_credentials, str):
import google.oauth2.service_account

json_obj = json.loads(vertex_credentials)

creds = google.oauth2.service_account.Credentials.from_service_account_info(
json_obj,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
else:
creds, _ = google.auth.default(quota_project_id=vertex_project)
print_verbose(
f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
)
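
The completion() and embedding() hunks now accept vertex_credentials as a service-account JSON string and only fall back to application-default credentials when it is absent. A hedged sketch of that branch in isolation (the JSON string and project name are placeholders; a real key is required to run it):

import json
from typing import Optional

import google.auth
import google.oauth2.service_account


def load_vertex_credentials(vertex_credentials: Optional[str], vertex_project: Optional[str]):
    """Return Google credentials from an inline service-account JSON string, else ADC."""
    if vertex_credentials is not None and isinstance(vertex_credentials, str):
        info = json.loads(vertex_credentials)
        return google.oauth2.service_account.Credentials.from_service_account_info(
            info,
            scopes=["https://www.googleapis.com/auth/cloud-platform"],
        )
    # Fall back to application-default credentials, as the diff does.
    creds, _ = google.auth.default(quota_project_id=vertex_project)
    return creds


# Example (placeholder inputs):
# creds = load_vertex_credentials(open("sa.json").read(), "my-gcp-project")
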
@ -417,13 +538,14 @@ def completion(
|
|||
return async_completion(**data)
|
||||
|
||||
if mode == "vision":
|
||||
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
|
||||
print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
|
||||
print_verbose(f"\nProcessing input messages = {messages}")
|
||||
tools = optional_params.pop("tools", None)
|
||||
prompt, images = _gemini_vision_convert_messages(messages=messages)
|
||||
content = [prompt] + images
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
stream = optional_params.pop("stream")
|
||||
stream = optional_params.pop("stream", False)
|
||||
if stream == True:
|
||||
|
||||
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
|
@ -436,12 +558,12 @@ def completion(
|
|||
|
||||
model_response = llm_model.generate_content(
|
||||
contents=content,
|
||||
generation_config=GenerationConfig(**optional_params),
|
||||
generation_config=optional_params,
|
||||
safety_settings=safety_settings,
|
||||
stream=True,
|
||||
tools=tools,
|
||||
)
|
||||
optional_params["stream"] = True
|
||||
|
||||
return model_response
|
||||
|
||||
request_str += f"response = llm_model.generate_content({content})\n"
|
||||
|
@ -458,7 +580,7 @@ def completion(
|
|||
## LLM Call
|
||||
response = llm_model.generate_content(
|
||||
contents=content,
|
||||
generation_config=GenerationConfig(**optional_params),
|
||||
generation_config=optional_params,
|
||||
safety_settings=safety_settings,
|
||||
tools=tools,
|
||||
)
|
||||
|
@ -513,7 +635,7 @@ def completion(
|
|||
},
|
||||
)
|
||||
model_response = chat.send_message_streaming(prompt, **optional_params)
|
||||
optional_params["stream"] = True
|
||||
|
||||
return model_response
|
||||
|
||||
request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
|
||||
|
@ -545,7 +667,7 @@ def completion(
|
|||
},
|
||||
)
|
||||
model_response = llm_model.predict_streaming(prompt, **optional_params)
|
||||
optional_params["stream"] = True
|
||||
|
||||
return model_response
|
||||
|
||||
request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
|
||||
|
@ -670,7 +792,7 @@ def completion(
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
except Exception as e:
|
||||
raise VertexAIError(status_code=500, message=str(e))
|
||||
|
@ -697,12 +819,11 @@ async def async_completion(
|
|||
Add support for acompletion calls for gemini-pro
|
||||
"""
|
||||
try:
|
||||
from vertexai.preview.generative_models import GenerationConfig
|
||||
|
||||
if mode == "vision":
|
||||
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
|
||||
print_verbose("\nMaking VertexAI Gemini Pro/Vision Call")
|
||||
print_verbose(f"\nProcessing input messages = {messages}")
|
||||
tools = optional_params.pop("tools", None)
|
||||
stream = optional_params.pop("stream", False)
|
||||
|
||||
prompt, images = _gemini_vision_convert_messages(messages=messages)
|
||||
content = [prompt] + images
|
||||
|
@ -719,14 +840,15 @@ async def async_completion(
|
|||
)
|
||||
|
||||
## LLM Call
|
||||
# print(f"final content: {content}")
|
||||
response = await llm_model._generate_content_async(
|
||||
contents=content,
|
||||
generation_config=GenerationConfig(**optional_params),
|
||||
generation_config=optional_params,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
if tools is not None and hasattr(
|
||||
response.candidates[0].content.parts[0], "function_call"
|
||||
if tools is not None and bool(
|
||||
getattr(response.candidates[0].content.parts[0], "function_call", None)
|
||||
):
|
||||
function_call = response.candidates[0].content.parts[0].function_call
|
||||
args_dict = {}
|
||||
|
@ -879,7 +1001,7 @@ async def async_completion(
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
except Exception as e:
|
||||
raise VertexAIError(status_code=500, message=str(e))
|
||||
|
@ -905,8 +1027,6 @@ async def async_streaming(
|
|||
"""
|
||||
Add support for async streaming calls for gemini-pro
|
||||
"""
|
||||
from vertexai.preview.generative_models import GenerationConfig
|
||||
|
||||
if mode == "vision":
|
||||
stream = optional_params.pop("stream")
|
||||
tools = optional_params.pop("tools", None)
|
||||
|
@ -927,11 +1047,10 @@ async def async_streaming(
|
|||
|
||||
response = await llm_model._generate_content_streaming_async(
|
||||
contents=content,
|
||||
generation_config=GenerationConfig(**optional_params),
|
||||
generation_config=optional_params,
|
||||
tools=tools,
|
||||
)
|
||||
optional_params["stream"] = True
|
||||
optional_params["tools"] = tools
|
||||
|
||||
elif mode == "chat":
|
||||
chat = llm_model.start_chat()
|
||||
optional_params.pop(
|
||||
|
@ -950,7 +1069,7 @@ async def async_streaming(
|
|||
},
|
||||
)
|
||||
response = chat.send_message_streaming_async(prompt, **optional_params)
|
||||
optional_params["stream"] = True
|
||||
|
||||
elif mode == "text":
|
||||
optional_params.pop(
|
||||
"stream", None
|
||||
|
@ -1046,6 +1165,7 @@ def embedding(
|
|||
encoding=None,
|
||||
vertex_project=None,
|
||||
vertex_location=None,
|
||||
vertex_credentials=None,
|
||||
aembedding=False,
|
||||
print_verbose=None,
|
||||
):
|
||||
|
@ -1066,7 +1186,17 @@ def embedding(
|
|||
print_verbose(
|
||||
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
|
||||
)
|
||||
creds, _ = google.auth.default(quota_project_id=vertex_project)
|
||||
if vertex_credentials is not None and isinstance(vertex_credentials, str):
|
||||
import google.oauth2.service_account
|
||||
|
||||
json_obj = json.loads(vertex_credentials)
|
||||
|
||||
creds = google.oauth2.service_account.Credentials.from_service_account_info(
|
||||
json_obj,
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
else:
|
||||
creds, _ = google.auth.default(quota_project_id=vertex_project)
|
||||
print_verbose(
|
||||
f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
|
||||
)
|
||||
|
|
litellm/llms/vertex_ai_anthropic.py (new file, 469 lines)
@ -0,0 +1,469 @@
|
|||
# What is this?
|
||||
## Handler file for calling claude-3 on vertex ai
|
||||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests, copy
|
||||
import time, uuid
|
||||
from typing import Callable, Optional, List
|
||||
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from .prompt_templates.factory import (
|
||||
contains_tag,
|
||||
prompt_factory,
|
||||
custom_prompt,
|
||||
construct_tool_use_system_prompt,
|
||||
extract_between_tags,
|
||||
parse_xml_params,
|
||||
)
|
||||
import httpx
|
||||
|
||||
|
||||
class VertexAIError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(
|
||||
method="POST", url=" https://cloud.google.com/vertex-ai/"
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class VertexAIAnthropicConfig:
|
||||
"""
|
||||
Reference: https://docs.anthropic.com/claude/reference/messages_post
|
||||
|
||||
Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
|
||||
|
||||
- `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL.
|
||||
- `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16".
|
||||
|
||||
The class `VertexAIAnthropicConfig` provides configuration for the VertexAI's Anthropic API interface. Below are the parameters:
|
||||
|
||||
- `max_tokens` Required (integer) max tokens,
|
||||
- `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
|
||||
- `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
|
||||
- `temperature` Optional (float) The amount of randomness injected into the response
|
||||
- `top_p` Optional (float) Use nucleus sampling.
|
||||
- `top_k` Optional (int) Only sample from the top K options for each subsequent token
|
||||
- `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
|
||||
|
||||
Note: Please make sure to modify the default parameters as required for your use case.
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = (
|
||||
4096 # anthropic max - setting this doesn't impact response, but is required by anthropic.
|
||||
)
|
||||
system: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
anthropic_version: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key == "max_tokens" and value is None:
|
||||
value = self.max_tokens
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
return [
|
||||
"max_tokens",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"stream",
|
||||
"stop",
|
||||
"temperature",
|
||||
"top_p",
|
||||
]
|
||||
|
||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "tools":
|
||||
optional_params["tools"] = value
|
||||
if param == "stream":
|
||||
optional_params["stream"] = value
|
||||
if param == "stop":
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
return optional_params
|
||||
|
||||
|
||||
"""
|
||||
- Run client init
|
||||
- Support async completion, streaming
|
||||
"""
|
||||
|
||||
|
||||
# makes headers for API call
|
||||
def refresh_auth(
|
||||
credentials,
|
||||
) -> str: # used when user passes in credentials as json string
|
||||
from google.auth.transport.requests import Request # type: ignore[import-untyped]
|
||||
|
||||
if credentials.token is None:
|
||||
credentials.refresh(Request())
|
||||
|
||||
if not credentials.token:
|
||||
raise RuntimeError("Could not resolve API token from the credentials")
|
||||
|
||||
return credentials.token
|
||||
|
||||
|
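
refresh_auth above forces a token refresh on freshly built service-account credentials so the AnthropicVertex client can be handed a bearer token. A short sketch of the same flow with credentials loaded from a JSON string (placeholder inputs; a real service-account key is required to run it):

import json

import google.oauth2.service_account
from google.auth.transport.requests import Request


def access_token_from_service_account(vertex_credentials: str) -> str:
    """Build credentials from an inline key and return a refreshed OAuth token."""
    info = json.loads(vertex_credentials)
    creds = google.oauth2.service_account.Credentials.from_service_account_info(
        info, scopes=["https://www.googleapis.com/auth/cloud-platform"]
    )
    if creds.token is None:
        creds.refresh(Request())  # same refresh step as refresh_auth above
    if not creds.token:
        raise RuntimeError("Could not resolve API token from the credentials")
    return creds.token


# token = access_token_from_service_account(open("sa.json").read())  # placeholder path
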
||||
def completion(
    model: str,
    messages: list,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    logging_obj,
    vertex_project=None,
    vertex_location=None,
    vertex_credentials=None,
    optional_params=None,
    litellm_params=None,
    logger_fn=None,
    acompletion: bool = False,
    client=None,
):
    try:
        import vertexai
        from anthropic import AnthropicVertex
    except:
        raise VertexAIError(
            status_code=400,
            message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
        )

    if not (
        hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models")
    ):
        raise VertexAIError(
            status_code=400,
            message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
        )
    try:
        ## Load Config
        config = litellm.VertexAIAnthropicConfig.get_config()
        for k, v in config.items():
            if k not in optional_params:
                optional_params[k] = v

        ## Format Prompt
        _is_function_call = False
        messages = copy.deepcopy(messages)
        optional_params = copy.deepcopy(optional_params)
        # Separate system prompt from rest of message
        system_prompt_indices = []
        system_prompt = ""
        for idx, message in enumerate(messages):
            if message["role"] == "system":
                system_prompt += message["content"]
                system_prompt_indices.append(idx)
        if len(system_prompt_indices) > 0:
            for idx in reversed(system_prompt_indices):
                messages.pop(idx)
        if len(system_prompt) > 0:
            optional_params["system"] = system_prompt
        # Format rest of message according to anthropic guidelines
        try:
            messages = prompt_factory(
                model=model, messages=messages, custom_llm_provider="anthropic_xml"
            )
        except Exception as e:
            raise VertexAIError(status_code=400, message=str(e))

        ## Handle Tool Calling
        if "tools" in optional_params:
            _is_function_call = True
            tool_calling_system_prompt = construct_tool_use_system_prompt(
                tools=optional_params["tools"]
            )
            optional_params["system"] = (
                optional_params.get("system", "\n") + tool_calling_system_prompt
            )  # add the anthropic tool calling prompt to the system prompt
            optional_params.pop("tools")

        stream = optional_params.pop("stream", None)

        data = {
            "model": model,
            "messages": messages,
            **optional_params,
        }
        print_verbose(f"_is_function_call: {_is_function_call}")

        ## Completion Call

        print_verbose(
            f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
        )
        access_token = None
        if client is None:
            if vertex_credentials is not None and isinstance(vertex_credentials, str):
                import google.oauth2.service_account

                json_obj = json.loads(vertex_credentials)

                creds = (
                    google.oauth2.service_account.Credentials.from_service_account_info(
                        json_obj,
                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
                    )
                )
                ### CHECK IF ACCESS
                access_token = refresh_auth(credentials=creds)

            vertex_ai_client = AnthropicVertex(
                project_id=vertex_project,
                region=vertex_location,
                access_token=access_token,
            )
        else:
            vertex_ai_client = client

        if acompletion == True:
            """
            - async streaming
            - async completion
            """
            if stream is not None and stream == True:
                return async_streaming(
                    model=model,
                    messages=messages,
                    data=data,
                    print_verbose=print_verbose,
                    model_response=model_response,
                    logging_obj=logging_obj,
                    vertex_project=vertex_project,
                    vertex_location=vertex_location,
                    optional_params=optional_params,
                    client=client,
                    access_token=access_token,
                )
            else:
                return async_completion(
                    model=model,
                    messages=messages,
                    data=data,
                    print_verbose=print_verbose,
                    model_response=model_response,
                    logging_obj=logging_obj,
                    vertex_project=vertex_project,
                    vertex_location=vertex_location,
                    optional_params=optional_params,
                    client=client,
                    access_token=access_token,
                )
        if stream is not None and stream == True:
            ## LOGGING
            logging_obj.pre_call(
                input=messages,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                },
            )
            response = vertex_ai_client.messages.create(**data, stream=True)  # type: ignore
            return response

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
            },
        )

        message = vertex_ai_client.messages.create(**data)  # type: ignore
        text_content = message.content[0].text
        ## TOOL CALLING - OUTPUT PARSE
        if text_content is not None and contains_tag("invoke", text_content):
            function_name = extract_between_tags("tool_name", text_content)[0]
            function_arguments_str = extract_between_tags("invoke", text_content)[
                0
            ].strip()
            function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
            function_arguments = parse_xml_params(function_arguments_str)
            _message = litellm.Message(
                tool_calls=[
                    {
                        "id": f"call_{uuid.uuid4()}",
                        "type": "function",
                        "function": {
                            "name": function_name,
                            "arguments": json.dumps(function_arguments),
                        },
                    }
                ],
                content=None,
            )
            model_response.choices[0].message = _message  # type: ignore
        else:
            model_response.choices[0].message.content = text_content  # type: ignore
        model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)

        ## CALCULATING USAGE
        prompt_tokens = message.usage.input_tokens
        completion_tokens = message.usage.output_tokens

        model_response["created"] = int(time.time())
        model_response["model"] = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)
        return model_response
    except Exception as e:
        raise VertexAIError(status_code=500, message=str(e))
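

# A hedged sketch of how this code path is typically reached through litellm's
# top-level API. The helper name, model string, project id, and region below are
# illustrative assumptions rather than values taken from this diff.
def _example_vertex_claude_call():
    import litellm

    return litellm.completion(
        model="vertex_ai/claude-3-sonnet@20240229",  # assumed model identifier
        messages=[{"role": "user", "content": "Hello from Vertex AI"}],
        vertex_project="my-gcp-project",  # assumed GCP project id
        vertex_location="us-central1",  # assumed region
        max_tokens=256,
    )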


async def async_completion(
    model: str,
    messages: list,
    data: dict,
    model_response: ModelResponse,
    print_verbose: Callable,
    logging_obj,
    vertex_project=None,
    vertex_location=None,
    optional_params=None,
    client=None,
    access_token=None,
):
    from anthropic import AsyncAnthropicVertex

    if client is None:
        vertex_ai_client = AsyncAnthropicVertex(
            project_id=vertex_project, region=vertex_location, access_token=access_token
        )
    else:
        vertex_ai_client = client

    ## LOGGING
    logging_obj.pre_call(
        input=messages,
        api_key=None,
        additional_args={
            "complete_input_dict": optional_params,
        },
    )
    message = await vertex_ai_client.messages.create(**data)  # type: ignore
    text_content = message.content[0].text
    ## TOOL CALLING - OUTPUT PARSE
    if text_content is not None and contains_tag("invoke", text_content):
        function_name = extract_between_tags("tool_name", text_content)[0]
        function_arguments_str = extract_between_tags("invoke", text_content)[0].strip()
        function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
        function_arguments = parse_xml_params(function_arguments_str)
        _message = litellm.Message(
            tool_calls=[
                {
                    "id": f"call_{uuid.uuid4()}",
                    "type": "function",
                    "function": {
                        "name": function_name,
                        "arguments": json.dumps(function_arguments),
                    },
                }
            ],
            content=None,
        )
        model_response.choices[0].message = _message  # type: ignore
    else:
        model_response.choices[0].message.content = text_content  # type: ignore
    model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)

    ## CALCULATING USAGE
    prompt_tokens = message.usage.input_tokens
    completion_tokens = message.usage.output_tokens

    model_response["created"] = int(time.time())
    model_response["model"] = model
    usage = Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    setattr(model_response, "usage", usage)
    return model_response


async def async_streaming(
    model: str,
    messages: list,
    data: dict,
    model_response: ModelResponse,
    print_verbose: Callable,
    logging_obj,
    vertex_project=None,
    vertex_location=None,
    optional_params=None,
    client=None,
    access_token=None,
):
    from anthropic import AsyncAnthropicVertex

    if client is None:
        vertex_ai_client = AsyncAnthropicVertex(
            project_id=vertex_project, region=vertex_location, access_token=access_token
        )
    else:
        vertex_ai_client = client

    ## LOGGING
    logging_obj.pre_call(
        input=messages,
        api_key=None,
        additional_args={
            "complete_input_dict": optional_params,
        },
    )
    response = await vertex_ai_client.messages.create(**data, stream=True)  # type: ignore
    logging_obj.post_call(input=messages, api_key=None, original_response=response)

    streamwrapper = CustomStreamWrapper(
        completion_stream=response,
        model=model,
        custom_llm_provider="vertex_ai",
        logging_obj=logging_obj,
    )

    return streamwrapper
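

# A hedged sketch of consuming the wrapper returned above. The helper name is an
# assumption; it presumes the CustomStreamWrapper supports `async for` iteration
# and yields OpenAI-style chunks (choices[0].delta.content), which is how
# litellm streaming responses are consumed elsewhere.
async def _example_consume_vertex_claude_stream(streamwrapper) -> str:
    collected = []
    async for chunk in streamwrapper:
        delta = chunk.choices[0].delta.content
        if delta is not None:
            collected.append(delta)
    return "".join(collected)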


@ -104,7 +104,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
-model_response.usage = usage
+setattr(model_response, "usage", usage)
return model_response

@ -186,7 +186,7 @@ def batch_completions(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
-model_response.usage = usage
+setattr(model_response, "usage", usage)
final_outputs.append(model_response)
return final_outputs