Merge branch 'BerriAI:main' into fix-anthropic-messages-api

This commit is contained in:
Emir Ayar 2024-04-27 11:50:04 +02:00 committed by GitHub
commit 38b5f34c77
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
366 changed files with 73092 additions and 56717 deletions

View file

@ -298,7 +298,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -2,18 +2,13 @@ import os, types
import json
from enum import Enum
import requests, copy
import time, uuid
import time
from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm
from .prompt_templates.factory import (
contains_tag,
prompt_factory,
custom_prompt,
construct_tool_use_system_prompt,
extract_between_tags,
parse_xml_params,
)
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx
@ -21,6 +16,8 @@ class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
# constants from https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/_constants.py
class AnthropicError(Exception):
def __init__(self, status_code, message):
@ -37,12 +34,14 @@ class AnthropicError(Exception):
class AnthropicConfig:
"""
Reference: https://docs.anthropic.com/claude/reference/complete_post
Reference: https://docs.anthropic.com/claude/reference/messages_post
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
"""
max_tokens: Optional[int] = litellm.max_tokens # anthropic requires a default
max_tokens: Optional[int] = (
4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
)
stop_sequences: Optional[list] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
@ -52,7 +51,9 @@ class AnthropicConfig:
def __init__(
self,
max_tokens: Optional[int] = 256, # anthropic requires a default
max_tokens: Optional[
int
] = 4096, # You can pass in a value yourself or use the default value 4096
stop_sequences: Optional[list] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
@ -101,124 +102,23 @@ def validate_environment(api_key, user_headers):
return headers
def completion(
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
):
headers = validate_environment(api_key, headers)
_is_function_call = False
json_schemas: dict = {}
messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
# Separate system prompt from rest of message
system_prompt_indices = []
system_prompt = ""
for idx, message in enumerate(messages):
if message["role"] == "system":
system_prompt += message["content"]
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
if len(system_prompt) > 0:
optional_params["system"] = system_prompt
# Format rest of message according to anthropic guidelines
try:
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
except Exception as e:
raise AnthropicError(status_code=400, message=str(e))
## Load Config
config = litellm.AnthropicConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
for tool in optional_params["tools"]:
json_schemas[tool["function"]["name"]] = tool["function"].get(
"parameters", None
)
tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=optional_params["tools"]
)
optional_params["system"] = (
optional_params.get("system", "\n") + tool_calling_system_prompt
) # add the anthropic tool calling prompt to the system prompt
optional_params.pop("tools")
stream = optional_params.pop("stream", None)
data = {
"model": model,
"messages": messages,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
print_verbose(f"_is_function_call: {_is_function_call}")
## COMPLETION CALL
if (
stream is not None and stream == True and _is_function_call == False
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose(f"makes anthropic streaming POST request")
data["stream"] = stream
response = requests.post(
api_base,
headers=headers,
data=json.dumps(data),
stream=stream,
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
return response.iter_lines()
else:
response = requests.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
class AnthropicChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def process_response(
self,
model,
response,
model_response,
_is_function_call,
stream,
logging_obj,
api_key,
data,
messages,
print_verbose,
):
## LOGGING
logging_obj.post_call(
input=messages,
@ -245,46 +145,40 @@ def completion(
status_code=response.status_code,
)
else:
text_content = completion_response["content"][0].get("text", None)
## TOOL CALLING - OUTPUT PARSE
if text_content is not None and contains_tag("invoke", text_content):
function_name = extract_between_tags("tool_name", text_content)[0]
function_arguments_str = extract_between_tags("invoke", text_content)[
0
].strip()
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
function_arguments = parse_xml_params(
function_arguments_str,
json_schema=json_schemas.get(
function_name, None
), # check if we have a json schema for this function name
)
_message = litellm.Message(
tool_calls=[
text_content = ""
tool_calls = []
for content in completion_response["content"]:
if content["type"] == "text":
text_content += content["text"]
## TOOL CALLING
elif content["type"] == "tool_use":
tool_calls.append(
{
"id": f"call_{uuid.uuid4()}",
"id": content["id"],
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(function_arguments),
"name": content["name"],
"arguments": json.dumps(content["input"]),
},
}
],
content=None,
)
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = (
text_content # allow user to access raw anthropic tool calling response
)
else:
model_response.choices[0].message.content = text_content # type: ignore
)
_message = litellm.Message(
tool_calls=tool_calls,
content=text_content or None,
)
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = completion_response[
"content"
] # allow user to access raw anthropic tool calling response
model_response.choices[0].finish_reason = map_finish_reason(
completion_response["stop_reason"]
)
print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
if _is_function_call == True and stream is not None and stream == True:
print_verbose(f"INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
if _is_function_call and stream:
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# return an iterator
streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = model_response.choices[
@ -318,7 +212,7 @@ def completion(
model_response=streaming_model_response
)
print_verbose(
f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
)
return CustomStreamWrapper(
completion_stream=completion_stream,
@ -337,11 +231,278 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
total_tokens=total_tokens,
)
model_response.usage = usage
return model_response
async def acompletion_stream_function(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
stream,
_is_function_call,
data=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
data["stream"] = True
response = await self.async_handler.post(
api_base, headers=headers, data=json.dumps(data), stream=True
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
completion_stream = response.aiter_lines()
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic",
logging_obj=logging_obj,
)
return streamwrapper
async def acompletion_function(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
stream,
_is_function_call,
data=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
response = await self.async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)
return self.process_response(
model=model,
response=response,
model_response=model_response,
_is_function_call=_is_function_call,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
)
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params=None,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
):
headers = validate_environment(api_key, headers)
_is_function_call = False
messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
# Separate system prompt from rest of message
system_prompt_indices = []
system_prompt = ""
for idx, message in enumerate(messages):
if message["role"] == "system":
system_prompt += message["content"]
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
if len(system_prompt) > 0:
optional_params["system"] = system_prompt
# Format rest of message according to anthropic guidelines
try:
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
except Exception as e:
raise AnthropicError(status_code=400, message=str(e))
## Load Config
config = litellm.AnthropicConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
headers["anthropic-beta"] = "tools-2024-04-04"
anthropic_tools = []
for tool in optional_params["tools"]:
new_tool = tool["function"]
new_tool["input_schema"] = new_tool.pop("parameters") # rename key
anthropic_tools.append(new_tool)
optional_params["tools"] = anthropic_tools
stream = optional_params.pop("stream", None)
data = {
"model": model,
"messages": messages,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
print_verbose(f"_is_function_call: {_is_function_call}")
if acompletion == True:
if (
stream and not _is_function_call
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose("makes async anthropic streaming POST request")
data["stream"] = stream
return self.acompletion_stream_function(
model=model,
messages=messages,
data=data,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream,
_is_function_call=_is_function_call,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
)
else:
return self.acompletion_function(
model=model,
messages=messages,
data=data,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream,
_is_function_call=_is_function_call,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
)
else:
## COMPLETION CALL
if (
stream and not _is_function_call
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose("makes anthropic streaming POST request")
data["stream"] = stream
response = requests.post(
api_base,
headers=headers,
data=json.dumps(data),
stream=stream,
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
completion_stream = response.iter_lines()
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic",
logging_obj=logging_obj,
)
return streaming_response
else:
response = requests.post(
api_base, headers=headers, data=json.dumps(data)
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
return self.process_response(
model=model,
response=response,
model_response=model_response,
_is_function_call=_is_function_call,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
)
def embedding(self):
# logic for parsing in - calling - parsing out model embedding calls
pass
class ModelResponseIterator:
def __init__(self, model_response):
@ -367,8 +528,3 @@ class ModelResponseIterator:
raise StopAsyncIteration
self.is_done = True
return self.model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass

View file

@ -4,10 +4,12 @@ from enum import Enum
import requests
import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
import httpx
from .base import BaseLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
class AnthropicConstants(Enum):
@ -94,91 +96,13 @@ def validate_environment(api_key, user_headers):
return headers
def completion(
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
):
headers = validate_environment(api_key, headers)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
class AnthropicTextCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
## Load Config
config = litellm.AnthropicTextConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
"model": model,
"prompt": prompt,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True:
response = requests.post(
api_base,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"],
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
return response.iter_lines()
else:
response = requests.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
def process_response(
self, model_response: ModelResponse, response, encoding, prompt: str, model: str
):
## RESPONSE OBJECT
try:
completion_response = response.json()
@ -213,10 +137,208 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response
async def async_completion(
self,
model: str,
model_response: ModelResponse,
api_base: str,
logging_obj,
encoding,
headers: dict,
data: dict,
client=None,
):
if client is None:
client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
response = await client.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
input=data["prompt"],
api_key=headers.get("x-api-key"),
original_response=response.text,
additional_args={"complete_input_dict": data},
)
response = self.process_response(
model_response=model_response,
response=response,
encoding=encoding,
prompt=data["prompt"],
model=model,
)
return response
async def async_streaming(
self,
model: str,
api_base: str,
logging_obj,
headers: dict,
data: Optional[dict],
client=None,
):
if client is None:
client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = await client.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
completion_stream = response.aiter_lines()
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic_text",
logging_obj=logging_obj,
)
return streamwrapper
def completion(
self,
model: str,
messages: list,
api_base: str,
acompletion: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
client=None,
):
headers = validate_environment(api_key, headers)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
## Load Config
config = litellm.AnthropicTextConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
"model": model,
"prompt": prompt,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True:
if acompletion == True:
return self.async_streaming(
model=model,
api_base=api_base,
logging_obj=logging_obj,
headers=headers,
data=data,
client=None,
)
if client is None:
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = client.post(
api_base,
headers=headers,
data=json.dumps(data),
# stream=optional_params["stream"],
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
completion_stream = response.iter_lines()
stream_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic_text",
logging_obj=logging_obj,
)
return stream_response
elif acompletion == True:
return self.async_completion(
model=model,
model_response=model_response,
api_base=api_base,
logging_obj=logging_obj,
encoding=encoding,
headers=headers,
data=data,
client=client,
)
else:
if client is None:
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = client.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
response = self.process_response(
model_response=model_response,
response=response,
encoding=encoding,
prompt=data["prompt"],
model=model,
)
return response
def embedding(self):
# logic for parsing in - calling - parsing out model embedding calls
pass

View file

@ -799,6 +799,7 @@ class AzureChatCompletion(BaseLLM):
optional_params: dict,
model_response: TranscriptionResponse,
timeout: float,
max_retries: int,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
@ -817,8 +818,6 @@ class AzureChatCompletion(BaseLLM):
"timeout": timeout,
}
max_retries = optional_params.pop("max_retries", None)
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)

View file

@ -8,6 +8,7 @@ from litellm.utils import (
CustomStreamWrapper,
convert_to_model_response_object,
TranscriptionResponse,
TextCompletionResponse,
)
from typing import Callable, Optional, BinaryIO
from litellm import OpenAIConfig
@ -15,11 +16,11 @@ import litellm, json
import httpx
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI
from ..llms.openai import OpenAITextCompletion
from ..llms.openai import OpenAITextCompletion, OpenAITextCompletionConfig
import uuid
from .prompt_templates.factory import prompt_factory, custom_prompt
openai_text_completion = OpenAITextCompletion()
openai_text_completion_config = OpenAITextCompletionConfig()
class AzureOpenAIError(Exception):
@ -300,9 +301,11 @@ class AzureTextCompletion(BaseLLM):
"api_base": api_base,
},
)
return openai_text_completion.convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
return (
openai_text_completion_config.convert_to_chat_model_response_object(
response_object=TextCompletionResponse(**stringified_response),
model_response_object=model_response,
)
)
except AzureOpenAIError as e:
exception_mapping_worked = True
@ -373,7 +376,7 @@ class AzureTextCompletion(BaseLLM):
},
)
response = await azure_client.completions.create(**data, timeout=timeout)
return openai_text_completion.convert_to_model_response_object(
return openai_text_completion_config.convert_to_chat_model_response_object(
response_object=response.model_dump(),
model_response_object=model_response,
)

View file

@ -55,9 +55,11 @@ def completion(
"inputs": prompt,
"prompt": prompt,
"parameters": optional_params,
"stream": True
if "stream" in optional_params and optional_params["stream"] == True
else False,
"stream": (
True
if "stream" in optional_params and optional_params["stream"] == True
else False
),
}
## LOGGING
@ -71,9 +73,11 @@ def completion(
completion_url_fragment_1 + model + completion_url_fragment_2,
headers=headers,
data=json.dumps(data),
stream=True
if "stream" in optional_params and optional_params["stream"] == True
else False,
stream=(
True
if "stream" in optional_params and optional_params["stream"] == True
else False
),
)
if "text/event-stream" in response.headers["Content-Type"] or (
"stream" in optional_params and optional_params["stream"] == True
@ -102,28 +106,28 @@ def completion(
and "data" in completion_response["model_output"]
and isinstance(completion_response["model_output"]["data"], list)
):
model_response["choices"][0]["message"][
"content"
] = completion_response["model_output"]["data"][0]
model_response["choices"][0]["message"]["content"] = (
completion_response["model_output"]["data"][0]
)
elif isinstance(completion_response["model_output"], str):
model_response["choices"][0]["message"][
"content"
] = completion_response["model_output"]
model_response["choices"][0]["message"]["content"] = (
completion_response["model_output"]
)
elif "completion" in completion_response and isinstance(
completion_response["completion"], str
):
model_response["choices"][0]["message"][
"content"
] = completion_response["completion"]
model_response["choices"][0]["message"]["content"] = (
completion_response["completion"]
)
elif isinstance(completion_response, list) and len(completion_response) > 0:
if "generated_text" not in completion_response:
raise BasetenError(
message=f"Unable to parse response. Original response: {response.text}",
status_code=response.status_code,
)
model_response["choices"][0]["message"][
"content"
] = completion_response[0]["generated_text"]
model_response["choices"][0]["message"]["content"] = (
completion_response[0]["generated_text"]
)
## GETTING LOGPROBS
if (
"details" in completion_response[0]
@ -155,7 +159,8 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -653,6 +653,10 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "meta":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
else:
prompt = ""
for message in messages:
@ -746,7 +750,7 @@ def completion(
]
# Format rest of message according to anthropic guidelines
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
model=model, messages=messages, custom_llm_provider="anthropic_xml"
)
## LOAD CONFIG
config = litellm.AmazonAnthropicClaude3Config.get_config()
@ -1008,7 +1012,7 @@ def completion(
)
streaming_choice.delta = delta_obj
streaming_model_response.choices = [streaming_choice]
completion_stream = model_response_iterator(
completion_stream = ModelResponseIterator(
model_response=streaming_model_response
)
print_verbose(
@ -1028,7 +1032,7 @@ def completion(
total_tokens=response_body["usage"]["input_tokens"]
+ response_body["usage"]["output_tokens"],
)
model_response.usage = _usage
setattr(model_response, "usage", _usage)
else:
outputText = response_body["completion"]
model_response["finish_reason"] = response_body["stop_reason"]
@ -1071,8 +1075,10 @@ def completion(
status_code=response_metadata.get("HTTPStatusCode", 500),
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
if getattr(model_response.usage, "total_tokens", None) is None:
## CALCULATING USAGE - bedrock charges on time, not tokens - have some mapping of cost here.
if not hasattr(model_response, "usage"):
setattr(model_response, "usage", Usage())
if getattr(model_response.usage, "total_tokens", None) is None: # type: ignore
prompt_tokens = response_metadata.get(
"x-amzn-bedrock-input-token-count", len(encoding.encode(prompt))
)
@ -1089,7 +1095,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
model_response["created"] = int(time.time())
model_response["model"] = model
@ -1109,8 +1115,30 @@ def completion(
raise BedrockError(status_code=500, message=traceback.format_exc())
async def model_response_iterator(model_response):
yield model_response
class ModelResponseIterator:
def __init__(self, model_response):
self.model_response = model_response
self.is_done = False
# Sync iterator
def __iter__(self):
return self
def __next__(self):
if self.is_done:
raise StopIteration
self.is_done = True
return self.model_response
# Async iterator
def __aiter__(self):
return self
async def __anext__(self):
if self.is_done:
raise StopAsyncIteration
self.is_done = True
return self.model_response
def _embedding_func_single(

View file

@ -167,7 +167,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -237,7 +237,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -43,6 +43,7 @@ class CohereChatConfig:
presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
seed (int, optional): A seed to assist reproducibility of the model's response.
"""
preamble: Optional[str] = None
@ -62,6 +63,7 @@ class CohereChatConfig:
presence_penalty: Optional[int] = None
tools: Optional[list] = None
tool_results: Optional[list] = None
seed: Optional[int] = None
def __init__(
self,
@ -82,6 +84,7 @@ class CohereChatConfig:
presence_penalty: Optional[int] = None,
tools: Optional[list] = None,
tool_results: Optional[list] = None,
seed: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
@ -302,5 +305,5 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -0,0 +1,96 @@
import httpx, asyncio
from typing import Optional, Union, Mapping, Any
# https://www.python-httpx.org/advanced/timeouts
_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
class AsyncHTTPHandler:
def __init__(
self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
):
# Create a client with a connection pool
self.client = httpx.AsyncClient(
timeout=timeout,
limits=httpx.Limits(
max_connections=concurrent_limit,
max_keepalive_connections=concurrent_limit,
),
)
async def close(self):
# Close the client when you're done with it
await self.client.aclose()
async def __aenter__(self):
return self.client
async def __aexit__(self):
# close the client when exiting
await self.client.aclose()
async def get(
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
):
response = await self.client.get(url, params=params, headers=headers)
return response
async def post(
self,
url: str,
data: Optional[Union[dict, str]] = None, # type: ignore
params: Optional[dict] = None,
headers: Optional[dict] = None,
stream: bool = False,
):
req = self.client.build_request(
"POST", url, data=data, params=params, headers=headers # type: ignore
)
response = await self.client.send(req, stream=stream)
return response
def __del__(self) -> None:
try:
asyncio.get_running_loop().create_task(self.close())
except Exception:
pass
class HTTPHandler:
def __init__(
self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
):
# Create a client with a connection pool
self.client = httpx.Client(
timeout=timeout,
limits=httpx.Limits(
max_connections=concurrent_limit,
max_keepalive_connections=concurrent_limit,
),
)
def close(self):
# Close the client when you're done with it
self.client.close()
def get(
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
):
response = self.client.get(url, params=params, headers=headers)
return response
def post(
self,
url: str,
data: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
):
response = self.client.post(url, data=data, params=params, headers=headers)
return response
def __del__(self) -> None:
try:
self.close()
except Exception:
pass

View file

@ -6,7 +6,8 @@ from typing import Callable, Optional
from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
import litellm
import sys, httpx
from .prompt_templates.factory import prompt_factory, custom_prompt
from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt
from packaging.version import Version
class GeminiError(Exception):
@ -103,6 +104,13 @@ class TextStreamer:
break
def supports_system_instruction():
import google.generativeai as genai
gemini_pkg_version = Version(genai.__version__)
return gemini_pkg_version >= Version("0.5.0")
def completion(
model: str,
messages: list,
@ -124,7 +132,7 @@ def completion(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
)
genai.configure(api_key=api_key)
system_prompt = ""
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
@ -135,6 +143,7 @@ def completion(
messages=messages,
)
else:
system_prompt, messages = get_system_prompt(messages=messages)
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="gemini"
)
@ -162,11 +171,20 @@ def completion(
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": {"inference_params": inference_params}},
additional_args={
"complete_input_dict": {
"inference_params": inference_params,
"system_prompt": system_prompt,
}
},
)
## COMPLETION CALL
try:
_model = genai.GenerativeModel(f"models/{model}")
_params = {"model_name": "models/{}".format(model)}
_system_instruction = supports_system_instruction()
if _system_instruction and len(system_prompt) > 0:
_params["system_instruction"] = system_prompt
_model = genai.GenerativeModel(**_params)
if stream == True:
if acompletion == True:
@ -213,11 +231,12 @@ def completion(
encoding=encoding,
)
else:
response = _model.generate_content(
contents=prompt,
generation_config=genai.types.GenerationConfig(**inference_params),
safety_settings=safety_settings,
)
params = {
"contents": prompt,
"generation_config": genai.types.GenerationConfig(**inference_params),
"safety_settings": safety_settings,
}
response = _model.generate_content(**params)
except Exception as e:
raise GeminiError(
message=str(e),
@ -292,7 +311,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -152,9 +152,9 @@ def completion(
else:
try:
if len(completion_response["answer"]) > 0:
model_response["choices"][0]["message"][
"content"
] = completion_response["answer"]
model_response["choices"][0]["message"]["content"] = (
completion_response["answer"]
)
except Exception as e:
raise MaritalkError(
message=response.text, status_code=response.status_code
@ -174,7 +174,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -185,9 +185,9 @@ def completion(
else:
try:
if len(completion_response["generated_text"]) > 0:
model_response["choices"][0]["message"][
"content"
] = completion_response["generated_text"]
model_response["choices"][0]["message"]["content"] = (
completion_response["generated_text"]
)
except:
raise NLPCloudError(
message=json.dumps(completion_response),
@ -205,7 +205,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -20,7 +20,7 @@ class OllamaError(Exception):
class OllamaConfig:
"""
Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters
Reference: https://github.com/ollama/ollama/blob/main/docs/api.md#parameters
The class `OllamaConfig` provides the configuration for the Ollama's API interface. Below are the parameters:
@ -69,7 +69,7 @@ class OllamaConfig:
repeat_penalty: Optional[float] = None
temperature: Optional[float] = None
stop: Optional[list] = (
None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
None # stop is a list based on this - https://github.com/ollama/ollama/pull/442
)
tfs_z: Optional[float] = None
num_predict: Optional[int] = None
@ -228,8 +228,8 @@ def get_ollama_response(
model_response["choices"][0]["message"]["content"] = response_json["response"]
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt))) # type: ignore
completion_tokens = response_json["eval_count"]
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt, disallowed_special=()))) # type: ignore
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
@ -330,8 +330,8 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
]
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + data["model"]
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"]))) # type: ignore
completion_tokens = response_json["eval_count"]
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"], disallowed_special=()))) # type: ignore
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,

View file

@ -20,7 +20,7 @@ class OllamaError(Exception):
class OllamaChatConfig:
"""
Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters
Reference: https://github.com/ollama/ollama/blob/main/docs/api.md#parameters
The class `OllamaConfig` provides the configuration for the Ollama's API interface. Below are the parameters:
@ -69,7 +69,7 @@ class OllamaChatConfig:
repeat_penalty: Optional[float] = None
temperature: Optional[float] = None
stop: Optional[list] = (
None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
None # stop is a list based on this - https://github.com/ollama/ollama/pull/442
)
tfs_z: Optional[float] = None
num_predict: Optional[int] = None
@ -148,7 +148,7 @@ class OllamaChatConfig:
if param == "top_p":
optional_params["top_p"] = value
if param == "frequency_penalty":
optional_params["repeat_penalty"] = param
optional_params["repeat_penalty"] = value
if param == "stop":
optional_params["stop"] = value
if param == "response_format" and value["type"] == "json_object":
@ -184,6 +184,7 @@ class OllamaChatConfig:
# ollama implementation
def get_ollama_response(
api_base="http://localhost:11434",
api_key: Optional[str] = None,
model="llama2",
messages=None,
optional_params=None,
@ -236,6 +237,7 @@ def get_ollama_response(
if stream == True:
response = ollama_async_streaming(
url=url,
api_key=api_key,
data=data,
model_response=model_response,
encoding=encoding,
@ -244,6 +246,7 @@ def get_ollama_response(
else:
response = ollama_acompletion(
url=url,
api_key=api_key,
data=data,
model_response=model_response,
encoding=encoding,
@ -252,12 +255,17 @@ def get_ollama_response(
)
return response
elif stream == True:
return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
return ollama_completion_stream(
url=url, api_key=api_key, data=data, logging_obj=logging_obj
)
response = requests.post(
url=f"{url}",
json=data,
)
_request = {
"url": f"{url}",
"json": data,
}
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
response = requests.post(**_request) # type: ignore
if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text)
@ -307,10 +315,16 @@ def get_ollama_response(
return model_response
def ollama_completion_stream(url, data, logging_obj):
with httpx.stream(
url=url, json=data, method="POST", timeout=litellm.request_timeout
) as response:
def ollama_completion_stream(url, api_key, data, logging_obj):
_request = {
"url": f"{url}",
"json": data,
"method": "POST",
"timeout": litellm.request_timeout,
}
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
with httpx.stream(**_request) as response:
try:
if response.status_code != 200:
raise OllamaError(
@ -329,12 +343,20 @@ def ollama_completion_stream(url, data, logging_obj):
raise e
async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
async def ollama_async_streaming(
url, api_key, data, model_response, encoding, logging_obj
):
try:
client = httpx.AsyncClient()
async with client.stream(
url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
) as response:
_request = {
"url": f"{url}",
"json": data,
"method": "POST",
"timeout": litellm.request_timeout,
}
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
async with client.stream(**_request) as response:
if response.status_code != 200:
raise OllamaError(
status_code=response.status_code, message=response.text
@ -353,13 +375,25 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
async def ollama_acompletion(
url, data, model_response, encoding, logging_obj, function_name
url,
api_key: Optional[str],
data,
model_response,
encoding,
logging_obj,
function_name,
):
data["stream"] = False
try:
timeout = aiohttp.ClientTimeout(total=litellm.request_timeout) # 10 minutes
async with aiohttp.ClientSession(timeout=timeout) as session:
resp = await session.post(url, json=data)
_request = {
"url": f"{url}",
"json": data,
}
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
resp = await session.post(**_request)
if resp.status != 200:
text = await resp.text()

View file

@ -99,9 +99,9 @@ def completion(
)
else:
try:
model_response["choices"][0]["message"][
"content"
] = completion_response["choices"][0]["message"]["content"]
model_response["choices"][0]["message"]["content"] = (
completion_response["choices"][0]["message"]["content"]
)
except:
raise OobaboogaError(
message=json.dumps(completion_response),
@ -115,7 +115,7 @@ def completion(
completion_tokens=completion_response["usage"]["completion_tokens"],
total_tokens=completion_response["usage"]["total_tokens"],
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -10,6 +10,7 @@ from litellm.utils import (
convert_to_model_response_object,
Usage,
TranscriptionResponse,
TextCompletionResponse,
)
from typing import Callable, Optional
import aiohttp, requests
@ -200,6 +201,43 @@ class OpenAITextCompletionConfig:
and v is not None
}
def convert_to_chat_model_response_object(
self,
response_object: Optional[TextCompletionResponse] = None,
model_response_object: Optional[ModelResponse] = None,
):
try:
## RESPONSE OBJECT
if response_object is None or model_response_object is None:
raise ValueError("Error in response object format")
choice_list = []
for idx, choice in enumerate(response_object["choices"]):
message = Message(
content=choice["text"],
role="assistant",
)
choice = Choices(
finish_reason=choice["finish_reason"], index=idx, message=message
)
choice_list.append(choice)
model_response_object.choices = choice_list
if "usage" in response_object:
setattr(model_response_object, "usage", response_object["usage"])
if "id" in response_object:
model_response_object.id = response_object["id"]
if "model" in response_object:
model_response_object.model = response_object["model"]
model_response_object._hidden_params["original_response"] = (
response_object # track original response, if users make a litellm.text_completion() request, we can return the original response
)
return model_response_object
except Exception as e:
raise e
class OpenAIChatCompletion(BaseLLM):
def __init__(self) -> None:
@ -785,10 +823,10 @@ class OpenAIChatCompletion(BaseLLM):
optional_params: dict,
model_response: TranscriptionResponse,
timeout: float,
max_retries: int,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None,
atranscription: bool = False,
):
@ -962,40 +1000,6 @@ class OpenAITextCompletion(BaseLLM):
headers["Authorization"] = f"Bearer {api_key}"
return headers
def convert_to_model_response_object(
self,
response_object: Optional[dict] = None,
model_response_object: Optional[ModelResponse] = None,
):
try:
## RESPONSE OBJECT
if response_object is None or model_response_object is None:
raise ValueError("Error in response object format")
choice_list = []
for idx, choice in enumerate(response_object["choices"]):
message = Message(content=choice["text"], role="assistant")
choice = Choices(
finish_reason=choice["finish_reason"], index=idx, message=message
)
choice_list.append(choice)
model_response_object.choices = choice_list
if "usage" in response_object:
model_response_object.usage = response_object["usage"]
if "id" in response_object:
model_response_object.id = response_object["id"]
if "model" in response_object:
model_response_object.model = response_object["model"]
model_response_object._hidden_params["original_response"] = (
response_object # track original response, if users make a litellm.text_completion() request, we can return the original response
)
return model_response_object
except Exception as e:
raise e
def completion(
self,
model_response: ModelResponse,
@ -1010,6 +1014,8 @@ class OpenAITextCompletion(BaseLLM):
optional_params=None,
litellm_params=None,
logger_fn=None,
client=None,
organization: Optional[str] = None,
headers: Optional[dict] = None,
):
super().completion()
@ -1020,8 +1026,6 @@ class OpenAITextCompletion(BaseLLM):
if model is None or messages is None:
raise OpenAIError(status_code=422, message=f"Missing model or messages")
api_base = f"{api_base}/completions"
if (
len(messages) > 0
and "content" in messages[0]
@ -1029,12 +1033,12 @@ class OpenAITextCompletion(BaseLLM):
):
prompt = messages[0]["content"]
else:
prompt = " ".join([message["content"] for message in messages]) # type: ignore
prompt = [message["content"] for message in messages] # type: ignore
# don't send max retries to the api, if set
optional_params.pop("max_retries", None)
data = {"model": model, "prompt": prompt, **optional_params}
max_retries = data.pop("max_retries", 2)
## LOGGING
logging_obj.pre_call(
input=messages,
@ -1050,38 +1054,53 @@ class OpenAITextCompletion(BaseLLM):
return self.async_streaming(
logging_obj=logging_obj,
api_base=api_base,
api_key=api_key,
data=data,
headers=headers,
model_response=model_response,
model=model,
timeout=timeout,
max_retries=max_retries,
client=client,
organization=organization,
)
else:
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout) # type: ignore
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client) # type: ignore
elif optional_params.get("stream", False):
return self.streaming(
logging_obj=logging_obj,
api_base=api_base,
api_key=api_key,
data=data,
headers=headers,
model_response=model_response,
model=model,
timeout=timeout,
max_retries=max_retries, # type: ignore
client=client,
organization=organization,
)
else:
response = httpx.post(
url=f"{api_base}", json=data, headers=headers, timeout=timeout
)
if response.status_code != 200:
raise OpenAIError(
status_code=response.status_code, message=response.text
if client is None:
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries, # type: ignore
organization=organization,
)
else:
openai_client = client
response = openai_client.completions.create(**data) # type: ignore
response_json = response.model_dump()
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response,
original_response=response_json,
additional_args={
"headers": headers,
"api_base": api_base,
@ -1089,10 +1108,7 @@ class OpenAITextCompletion(BaseLLM):
)
## RESPONSE OBJECT
return self.convert_to_model_response_object(
response_object=response.json(),
model_response_object=model_response,
)
return TextCompletionResponse(**response_json)
except Exception as e:
raise e
@ -1107,101 +1123,112 @@ class OpenAITextCompletion(BaseLLM):
api_key: str,
model: str,
timeout: float,
max_retries=None,
organization: Optional[str] = None,
client=None,
):
async with httpx.AsyncClient(timeout=timeout) as client:
try:
response = await client.post(
api_base,
json=data,
headers=headers,
timeout=litellm.request_timeout,
)
response_json = response.json()
if response.status_code != 200:
raise OpenAIError(
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
input=prompt,
try:
if client is None:
openai_aclient = AsyncOpenAI(
api_key=api_key,
original_response=response,
additional_args={
"headers": headers,
"api_base": api_base,
},
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_aclient = client
## RESPONSE OBJECT
return self.convert_to_model_response_object(
response_object=response_json, model_response_object=model_response
)
except Exception as e:
raise e
response = await openai_aclient.completions.create(**data)
response_json = response.model_dump()
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response,
additional_args={
"headers": headers,
"api_base": api_base,
},
)
## RESPONSE OBJECT
response_obj = TextCompletionResponse(**response_json)
response_obj._hidden_params.original_response = json.dumps(response_json)
return response_obj
except Exception as e:
raise e
def streaming(
self,
logging_obj,
api_base: str,
api_key: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str,
timeout: float,
api_base: Optional[str] = None,
max_retries=None,
client=None,
organization=None,
):
with httpx.stream(
url=f"{api_base}",
json=data,
headers=headers,
method="POST",
timeout=timeout,
) as response:
if response.status_code != 200:
raise OpenAIError(
status_code=response.status_code, message=response.text
)
streamwrapper = CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
if client is None:
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries, # type: ignore
organization=organization,
)
for transformed_chunk in streamwrapper:
yield transformed_chunk
else:
openai_client = client
response = openai_client.completions.create(**data)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
)
for chunk in streamwrapper:
yield chunk
async def async_streaming(
self,
logging_obj,
api_base: str,
api_key: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str,
timeout: float,
api_base: Optional[str] = None,
client=None,
max_retries=None,
organization=None,
):
client = httpx.AsyncClient()
async with client.stream(
url=f"{api_base}",
json=data,
headers=headers,
method="POST",
timeout=timeout,
) as response:
try:
if response.status_code != 200:
raise OpenAIError(
status_code=response.status_code, message=response.text
)
if client is None:
openai_client = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_client = client
streamwrapper = CustomStreamWrapper(
completion_stream=response.aiter_lines(),
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
raise e
response = await openai_client.completions.create(**data)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk

View file

@ -191,7 +191,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -41,9 +41,9 @@ class PetalsConfig:
"""
max_length: Optional[int] = None
max_new_tokens: Optional[
int
] = litellm.max_tokens # petals requires max tokens to be set
max_new_tokens: Optional[int] = (
litellm.max_tokens
) # petals requires max tokens to be set
do_sample: Optional[bool] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
@ -203,7 +203,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -1,9 +1,9 @@
from enum import Enum
import requests, traceback
import json, re, xml.etree.ElementTree as ET
from jinja2 import Template, exceptions, Environment, meta
from jinja2 import Template, exceptions, meta, BaseLoader
from jinja2.sandbox import ImmutableSandboxedEnvironment
from typing import Optional, Any
import imghdr, base64
from typing import List
import litellm
@ -62,7 +62,7 @@ def llama_2_chat_pt(messages):
def ollama_pt(
model, messages
): # https://github.com/jmorganca/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
): # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
if "instruct" in model:
prompt = custom_prompt(
role_dict={
@ -145,6 +145,12 @@ def mistral_api_pt(messages):
elif isinstance(m["content"], str):
texts = m["content"]
new_m = {"role": m["role"], "content": texts}
if new_m["role"] == "tool" and m.get("name"):
new_m["name"] = m["name"]
if m.get("tool_calls"):
new_m["tool_calls"] = m["tool_calls"]
new_messages.append(new_m)
return new_messages
@ -218,7 +224,36 @@ def phind_codellama_pt(messages):
return prompt
known_tokenizer_config = {
"mistralai/Mistral-7B-Instruct-v0.1": {
"tokenizer": {
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
"bos_token": "<s>",
"eos_token": "</s>",
},
"status": "success",
},
"meta-llama/Meta-Llama-3-8B-Instruct": {
"tokenizer": {
"chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
"bos_token": "<|begin_of_text|>",
"eos_token": "",
},
"status": "success",
},
}
def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] = None):
# Define Jinja2 environment
env = ImmutableSandboxedEnvironment()
def raise_exception(message):
raise Exception(f"Error message - {message}")
# Create a template object from the template text
env.globals["raise_exception"] = raise_exception
## get the tokenizer config from huggingface
bos_token = ""
eos_token = ""
@ -237,26 +272,23 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
else:
return {"status": "failure"}
tokenizer_config = _get_tokenizer_config(model)
if model in known_tokenizer_config:
tokenizer_config = known_tokenizer_config[model]
else:
tokenizer_config = _get_tokenizer_config(model)
if (
tokenizer_config["status"] == "failure"
or "chat_template" not in tokenizer_config["tokenizer"]
):
raise Exception("No chat template found")
## read the bos token, eos token and chat template from the json
tokenizer_config = tokenizer_config["tokenizer"]
bos_token = tokenizer_config["bos_token"]
eos_token = tokenizer_config["eos_token"]
chat_template = tokenizer_config["chat_template"]
tokenizer_config = tokenizer_config["tokenizer"] # type: ignore
def raise_exception(message):
raise Exception(f"Error message - {message}")
# Create a template object from the template text
env = Environment()
env.globals["raise_exception"] = raise_exception
bos_token = tokenizer_config["bos_token"] # type: ignore
eos_token = tokenizer_config["eos_token"] # type: ignore
chat_template = tokenizer_config["chat_template"] # type: ignore
try:
template = env.from_string(chat_template)
template = env.from_string(chat_template) # type: ignore
except Exception as e:
raise e
@ -463,10 +495,11 @@ def construct_tool_use_system_prompt(
): # from https://github.com/anthropics/anthropic-cookbook/blob/main/function_calling/function_calling.ipynb
tool_str_list = []
for tool in tools:
tool_function = get_attribute_or_key(tool, "function")
tool_str = construct_format_tool_for_claude_prompt(
tool["function"]["name"],
tool["function"].get("description", ""),
tool["function"].get("parameters", {}),
get_attribute_or_key(tool_function, "name"),
get_attribute_or_key(tool_function, "description", ""),
get_attribute_or_key(tool_function, "parameters", {}),
)
tool_str_list.append(tool_str)
tool_use_system_prompt = (
@ -556,7 +589,9 @@ def convert_to_anthropic_image_obj(openai_image_url: str):
)
def convert_to_anthropic_tool_result(message: dict) -> str:
# The following XML functions will be deprecated once JSON schema support is available on Bedrock and Vertex
# ------------------------------------------------------------------------------
def convert_to_anthropic_tool_result_xml(message: dict) -> str:
"""
OpenAI message with a tool result looks like:
{
@ -588,7 +623,8 @@ def convert_to_anthropic_tool_result(message: dict) -> str:
</function_results>
"""
name = message.get("name")
content = message.get("content")
content = message.get("content", "")
content = content.replace("<", "&lt;").replace(">", "&gt;").replace("&", "&amp;")
# We can't determine from openai message format whether it's a successful or
# error call result so default to the successful result template
@ -606,16 +642,18 @@ def convert_to_anthropic_tool_result(message: dict) -> str:
return anthropic_tool_result
def convert_to_anthropic_tool_invoke(tool_calls: list) -> str:
def convert_to_anthropic_tool_invoke_xml(tool_calls: list) -> str:
invokes = ""
for tool in tool_calls:
if tool["type"] != "function":
if get_attribute_or_key(tool, "type") != "function":
continue
tool_name = tool["function"]["name"]
tool_function = get_attribute_or_key(tool, "function")
tool_name = get_attribute_or_key(tool_function, "name")
tool_arguments = get_attribute_or_key(tool_function, "arguments")
parameters = "".join(
f"<{param}>{val}</{param}>\n"
for param, val in json.loads(tool["function"]["arguments"]).items()
for param, val in json.loads(tool_arguments).items()
)
invokes += (
"<invoke>\n"
@ -631,7 +669,7 @@ def convert_to_anthropic_tool_invoke(tool_calls: list) -> str:
return anthropic_tool_invoke
def anthropic_messages_pt(messages: list):
def anthropic_messages_pt_xml(messages: list):
"""
format messages for anthropic
1. Anthropic supports roles like "user" and "assistant", (here litellm translates system-> assistant)
@ -669,7 +707,7 @@ def anthropic_messages_pt(messages: list):
{
"type": "text",
"text": (
convert_to_anthropic_tool_result(messages[msg_i])
convert_to_anthropic_tool_result_xml(messages[msg_i])
if messages[msg_i]["role"] == "tool"
else messages[msg_i]["content"]
),
@ -696,7 +734,7 @@ def anthropic_messages_pt(messages: list):
if messages[msg_i].get(
"tool_calls", []
): # support assistant tool invoke convertion
assistant_text += convert_to_anthropic_tool_invoke(
assistant_text += convert_to_anthropic_tool_invoke_xml( # type: ignore
messages[msg_i]["tool_calls"]
)
@ -706,7 +744,192 @@ def anthropic_messages_pt(messages: list):
if assistant_content:
new_messages.append({"role": "assistant", "content": assistant_content})
if new_messages[0]["role"] != "user":
if not new_messages or new_messages[0]["role"] != "user":
if litellm.modify_params:
new_messages.insert(
0, {"role": "user", "content": [{"type": "text", "text": "."}]}
)
else:
raise Exception(
"Invalid first message. Should always start with 'role'='user' for Anthropic. System prompt is sent separately for Anthropic. set 'litellm.modify_params = True' or 'litellm_settings:modify_params = True' on proxy, to insert a placeholder user message - '.' as the first message, "
)
if new_messages[-1]["role"] == "assistant":
for content in new_messages[-1]["content"]:
if isinstance(content, dict) and content["type"] == "text":
content["text"] = content[
"text"
].rstrip() # no trailing whitespace for final assistant message
return new_messages
# ------------------------------------------------------------------------------
def convert_to_anthropic_tool_result(message: dict) -> dict:
"""
OpenAI message with a tool result looks like:
{
"tool_call_id": "tool_1",
"role": "tool",
"name": "get_current_weather",
"content": "function result goes here",
},
"""
"""
Anthropic tool_results look like:
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_01A09q90qw90lq917835lq9",
"content": "ConnectionError: the weather service API is not available (HTTP 500)",
# "is_error": true
}
]
}
"""
tool_call_id = message.get("tool_call_id")
content = message.get("content")
# We can't determine from openai message format whether it's a successful or
# error call result so default to the successful result template
anthropic_tool_result = {
"type": "tool_result",
"tool_use_id": tool_call_id,
"content": content,
}
return anthropic_tool_result
def convert_to_anthropic_tool_invoke(tool_calls: list) -> list:
"""
OpenAI tool invokes:
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_abc123",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": "{\n\"location\": \"Boston, MA\"\n}"
}
}
]
},
"""
"""
Anthropic tool invokes:
{
"role": "assistant",
"content": [
{
"type": "text",
"text": "<thinking>To answer this question, I will: 1. Use the get_weather tool to get the current weather in San Francisco. 2. Use the get_time tool to get the current time in the America/Los_Angeles timezone, which covers San Francisco, CA.</thinking>"
},
{
"type": "tool_use",
"id": "toolu_01A09q90qw90lq917835lq9",
"name": "get_weather",
"input": {"location": "San Francisco, CA"}
}
]
}
"""
anthropic_tool_invoke = [
{
"type": "tool_use",
"id": get_attribute_or_key(tool, "id"),
"name": get_attribute_or_key(
get_attribute_or_key(tool, "function"), "name"
),
"input": json.loads(
get_attribute_or_key(
get_attribute_or_key(tool, "function"), "arguments"
)
),
}
for tool in tool_calls
if get_attribute_or_key(tool, "type") == "function"
]
return anthropic_tool_invoke
def anthropic_messages_pt(messages: list):
"""
format messages for anthropic
1. Anthropic supports roles like "user" and "assistant", (here litellm translates system-> assistant)
2. The first message always needs to be of role "user"
3. Each message must alternate between "user" and "assistant" (this is not addressed as now by litellm)
4. final assistant content cannot end with trailing whitespace (anthropic raises an error otherwise)
5. System messages are a separate param to the Messages API
6. Ensure we only accept role, content. (message.name is not supported)
"""
# add role=tool support to allow function call result/error submission
user_message_types = {"user", "tool"}
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, merge them.
new_messages = []
msg_i = 0
while msg_i < len(messages):
user_content = []
## MERGE CONSECUTIVE USER CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
if isinstance(messages[msg_i]["content"], list):
for m in messages[msg_i]["content"]:
if m.get("type", "") == "image_url":
user_content.append(
{
"type": "image",
"source": convert_to_anthropic_image_obj(
m["image_url"]["url"]
),
}
)
elif m.get("type", "") == "text":
user_content.append({"type": "text", "text": m["text"]})
elif messages[msg_i]["role"] == "tool":
# OpenAI's tool message content will always be a string
user_content.append(convert_to_anthropic_tool_result(messages[msg_i]))
else:
user_content.append(
{"type": "text", "text": messages[msg_i]["content"]}
)
msg_i += 1
if user_content:
new_messages.append({"role": "user", "content": user_content})
assistant_content = []
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
assistant_text = (
messages[msg_i].get("content") or ""
) # either string or none
if assistant_text:
assistant_content.append({"type": "text", "text": assistant_text})
if messages[msg_i].get(
"tool_calls", []
): # support assistant tool invoke convertion
assistant_content.extend(
convert_to_anthropic_tool_invoke(messages[msg_i]["tool_calls"])
)
msg_i += 1
if assistant_content:
new_messages.append({"role": "assistant", "content": assistant_content})
if not new_messages or new_messages[0]["role"] != "user":
if litellm.modify_params:
new_messages.insert(
0, {"role": "user", "content": [{"type": "text", "text": "."}]}
@ -784,7 +1007,20 @@ def parse_xml_params(xml_content, json_schema: Optional[dict] = None):
return params
###
### GEMINI HELPER FUNCTIONS ###
def get_system_prompt(messages):
system_prompt_indices = []
system_prompt = ""
for idx, message in enumerate(messages):
if message["role"] == "system":
system_prompt += message["content"]
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
return system_prompt, messages
def convert_openai_message_to_cohere_tool_result(message):
@ -842,7 +1078,8 @@ def cohere_message_pt(messages: list):
tool_result = convert_openai_message_to_cohere_tool_result(message)
tool_results.append(tool_result)
else:
prompt += message["content"]
prompt += message["content"] + "\n\n"
prompt = prompt.rstrip()
return prompt, tool_results
@ -916,12 +1153,6 @@ def _gemini_vision_convert_messages(messages: list):
Returns:
tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images).
"""
try:
from PIL import Image
except:
raise Exception(
"gemini image conversion failed please run `pip install Pillow`"
)
try:
# given messages for gpt-4 vision, convert them for gemini
@ -948,6 +1179,12 @@ def _gemini_vision_convert_messages(messages: list):
image = _load_image_from_url(img)
processed_images.append(image)
else:
try:
from PIL import Image
except:
raise Exception(
"gemini image conversion failed please run `pip install Pillow`"
)
# Case 2: Image filepath (e.g. temp.jpeg) given
image = Image.open(img)
processed_images.append(image)
@ -1087,13 +1324,19 @@ def prompt_factory(
if model == "claude-instant-1" or model == "claude-2":
return anthropic_pt(messages=messages)
return anthropic_messages_pt(messages=messages)
elif custom_llm_provider == "anthropic_xml":
return anthropic_messages_pt_xml(messages=messages)
elif custom_llm_provider == "together_ai":
prompt_format, chat_template = get_model_info(token=api_key, model=model)
return format_prompt_togetherai(
messages=messages, prompt_format=prompt_format, chat_template=chat_template
)
elif custom_llm_provider == "gemini":
if model == "gemini-pro-vision":
if (
model == "gemini-pro-vision"
or litellm.supports_vision(model=model)
or litellm.supports_vision(model=custom_llm_provider + "/" + model)
):
return _gemini_vision_convert_messages(messages=messages)
else:
return gemini_text_image_pt(messages=messages)
@ -1109,6 +1352,13 @@ def prompt_factory(
return anthropic_pt(messages=messages)
elif "mistral." in model:
return mistral_instruct_pt(messages=messages)
elif "llama2" in model and "chat" in model:
return llama_2_chat_pt(messages=messages)
elif "llama3" in model and "instruct" in model:
return hf_chat_template(
model="meta-llama/Meta-Llama-3-8B-Instruct",
messages=messages,
)
elif custom_llm_provider == "perplexity":
for message in messages:
message.pop("name", None)
@ -1118,6 +1368,13 @@ def prompt_factory(
try:
if "meta-llama/llama-2" in model and "chat" in model:
return llama_2_chat_pt(messages=messages)
elif (
"meta-llama/llama-3" in model or "meta-llama-3" in model
) and "instruct" in model:
return hf_chat_template(
model="meta-llama/Meta-Llama-3-8B-Instruct",
messages=messages,
)
elif (
"tiiuae/falcon" in model
): # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
@ -1158,3 +1415,9 @@ def prompt_factory(
return default_pt(
messages=messages
) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
def get_attribute_or_key(tool_or_function, attribute, default=None):
if hasattr(tool_or_function, attribute):
return getattr(tool_or_function, attribute)
return tool_or_function.get(attribute, default)

View file

@ -112,10 +112,16 @@ def start_prediction(
}
initial_prediction_data = {
"version": version_id,
"input": input_data,
}
if ":" in version_id and len(version_id) > 64:
model_parts = version_id.split(":")
if (
len(model_parts) > 1 and len(model_parts[1]) == 64
): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
initial_prediction_data["version"] = model_parts[1]
## LOGGING
logging_obj.pre_call(
input=input_data["prompt"],
@ -307,9 +313,7 @@ def completion(
result, logs = handle_prediction_response(
prediction_url, api_key, print_verbose
)
model_response["ended"] = (
time.time()
) # for pricing this must remain right after calling api
## LOGGING
logging_obj.post_call(
input=prompt,
@ -332,9 +336,12 @@ def completion(
model_response["choices"][0]["message"]["content"] = result
# Calculate usage
prompt_tokens = len(encoding.encode(prompt))
prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
encoding.encode(
model_response["choices"][0]["message"].get("content", ""),
disallowed_special=(),
)
)
model_response["model"] = "replicate/" + model
usage = Usage(
@ -342,7 +349,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -365,7 +365,10 @@ def completion(
## RESPONSE OBJECT
completion_response = json.loads(response)
try:
completion_response_choices = completion_response[0]
if isinstance(completion_response, list):
completion_response_choices = completion_response[0]
else:
completion_response_choices = completion_response
completion_output = ""
if "generation" in completion_response_choices:
completion_output += completion_response_choices["generation"]
@ -396,7 +399,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response
@ -580,7 +583,10 @@ async def async_completion(
## RESPONSE OBJECT
completion_response = json.loads(response)
try:
completion_response_choices = completion_response[0]
if isinstance(completion_response, list):
completion_response_choices = completion_response[0]
else:
completion_response_choices = completion_response
completion_output = ""
if "generation" in completion_response_choices:
completion_output += completion_response_choices["generation"]
@ -611,7 +617,7 @@ async def async_completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -2,6 +2,7 @@
Deprecated. We now do together ai calls via the openai client.
Reference: https://docs.together.ai/docs/openai-api-compatibility
"""
import os, types
import json
from enum import Enum
@ -225,7 +226,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

File diff suppressed because it is too large Load diff

View file

@ -3,10 +3,10 @@ import json
from enum import Enum
import requests
import time
from typing import Callable, Optional, Union
from typing import Callable, Optional, Union, List
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
import litellm, uuid
import httpx
import httpx, inspect
class VertexAIError(Exception):
@ -22,9 +22,39 @@ class VertexAIError(Exception):
) # Call the base class constructor with the parameters it needs
class ExtendedGenerationConfig(dict):
"""Extended parameters for the generation."""
def __init__(
self,
*,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
candidate_count: Optional[int] = None,
max_output_tokens: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
response_mime_type: Optional[str] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
):
super().__init__(
temperature=temperature,
top_p=top_p,
top_k=top_k,
candidate_count=candidate_count,
max_output_tokens=max_output_tokens,
stop_sequences=stop_sequences,
response_mime_type=response_mime_type,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
class VertexAIConfig:
"""
Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
The class `VertexAIConfig` provides configuration for the VertexAI's API interface. Below are the parameters:
@ -36,6 +66,16 @@ class VertexAIConfig:
- `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
- `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'.
- `candidate_count` (int): Number of generated responses to return.
- `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
- `frequency_penalty` (float): This parameter is used to penalize the model from repeating the same output. The default value is 0.0.
- `presence_penalty` (float): This parameter is used to penalize the model from generating the same output as the input. The default value is 0.0.
Note: Please make sure to modify the default parameters as required for your use case.
"""
@ -43,6 +83,11 @@ class VertexAIConfig:
max_output_tokens: Optional[int] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
response_mime_type: Optional[str] = None
candidate_count: Optional[int] = None
stop_sequences: Optional[list] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
def __init__(
self,
@ -50,6 +95,11 @@ class VertexAIConfig:
max_output_tokens: Optional[int] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
response_mime_type: Optional[str] = None,
candidate_count: Optional[int] = None,
stop_sequences: Optional[list] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
@ -74,6 +124,66 @@ class VertexAIConfig:
and v is not None
}
def get_supported_openai_params(self):
return [
"temperature",
"top_p",
"max_tokens",
"stream",
"tools",
"tool_choice",
"response_format",
"n",
"stop",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if (
param == "stream" and value == True
): # sending stream = False, can cause it to get passed unchecked and raise issues
optional_params["stream"] = value
if param == "n":
optional_params["candidate_count"] = value
if param == "stop":
if isinstance(value, str):
optional_params["stop_sequences"] = [value]
elif isinstance(value, list):
optional_params["stop_sequences"] = value
if param == "max_tokens":
optional_params["max_output_tokens"] = value
if param == "response_format" and value["type"] == "json_object":
optional_params["response_mime_type"] = "application/json"
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value
if param == "presence_penalty":
optional_params["presence_penalty"] = value
if param == "tools" and isinstance(value, list):
from vertexai.preview import generative_models
gtool_func_declarations = []
for tool in value:
gtool_func_declaration = generative_models.FunctionDeclaration(
name=tool["function"]["name"],
description=tool["function"].get("description", ""),
parameters=tool["function"].get("parameters", {}),
)
gtool_func_declarations.append(gtool_func_declaration)
optional_params["tools"] = [
generative_models.Tool(
function_declarations=gtool_func_declarations
)
]
if param == "tool_choice" and (
isinstance(value, str) or isinstance(value, dict)
):
pass
return optional_params
import asyncio
@ -117,8 +227,7 @@ def _get_image_bytes_from_url(image_url: str) -> bytes:
image_bytes = response.content
return image_bytes
except requests.exceptions.RequestException as e:
# Handle any request exceptions (e.g., connection error, timeout)
return b"" # Return an empty bytes object or handle the error as needed
raise Exception(f"An exception occurs with this image - {str(e)}")
def _load_image_from_url(image_url: str):
@ -139,7 +248,8 @@ def _load_image_from_url(image_url: str):
)
image_bytes = _get_image_bytes_from_url(image_url)
return Image.from_bytes(image_bytes)
return Image.from_bytes(data=image_bytes)
def _gemini_vision_convert_messages(messages: list):
@ -257,6 +367,7 @@ def completion(
logging_obj,
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
@ -299,7 +410,17 @@ def completion(
print_verbose(
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
)
creds, _ = google.auth.default(quota_project_id=vertex_project)
if vertex_credentials is not None and isinstance(vertex_credentials, str):
import google.oauth2.service_account
json_obj = json.loads(vertex_credentials)
creds = google.oauth2.service_account.Credentials.from_service_account_info(
json_obj,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
else:
creds, _ = google.auth.default(quota_project_id=vertex_project)
print_verbose(
f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
)
@ -417,13 +538,14 @@ def completion(
return async_completion(**data)
if mode == "vision":
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
tools = optional_params.pop("tools", None)
prompt, images = _gemini_vision_convert_messages(messages=messages)
content = [prompt] + images
if "stream" in optional_params and optional_params["stream"] == True:
stream = optional_params.pop("stream")
stream = optional_params.pop("stream", False)
if stream == True:
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
logging_obj.pre_call(
input=prompt,
@ -436,12 +558,12 @@ def completion(
model_response = llm_model.generate_content(
contents=content,
generation_config=GenerationConfig(**optional_params),
generation_config=optional_params,
safety_settings=safety_settings,
stream=True,
tools=tools,
)
optional_params["stream"] = True
return model_response
request_str += f"response = llm_model.generate_content({content})\n"
@ -458,7 +580,7 @@ def completion(
## LLM Call
response = llm_model.generate_content(
contents=content,
generation_config=GenerationConfig(**optional_params),
generation_config=optional_params,
safety_settings=safety_settings,
tools=tools,
)
@ -513,7 +635,7 @@ def completion(
},
)
model_response = chat.send_message_streaming(prompt, **optional_params)
optional_params["stream"] = True
return model_response
request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
@ -545,7 +667,7 @@ def completion(
},
)
model_response = llm_model.predict_streaming(prompt, **optional_params)
optional_params["stream"] = True
return model_response
request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
@ -670,7 +792,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
@ -697,12 +819,11 @@ async def async_completion(
Add support for acompletion calls for gemini-pro
"""
try:
from vertexai.preview.generative_models import GenerationConfig
if mode == "vision":
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
print_verbose("\nMaking VertexAI Gemini Pro/Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
tools = optional_params.pop("tools", None)
stream = optional_params.pop("stream", False)
prompt, images = _gemini_vision_convert_messages(messages=messages)
content = [prompt] + images
@ -719,14 +840,15 @@ async def async_completion(
)
## LLM Call
# print(f"final content: {content}")
response = await llm_model._generate_content_async(
contents=content,
generation_config=GenerationConfig(**optional_params),
generation_config=optional_params,
tools=tools,
)
if tools is not None and hasattr(
response.candidates[0].content.parts[0], "function_call"
if tools is not None and bool(
getattr(response.candidates[0].content.parts[0], "function_call", None)
):
function_call = response.candidates[0].content.parts[0].function_call
args_dict = {}
@ -879,7 +1001,7 @@ async def async_completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
@ -905,8 +1027,6 @@ async def async_streaming(
"""
Add support for async streaming calls for gemini-pro
"""
from vertexai.preview.generative_models import GenerationConfig
if mode == "vision":
stream = optional_params.pop("stream")
tools = optional_params.pop("tools", None)
@ -927,11 +1047,10 @@ async def async_streaming(
response = await llm_model._generate_content_streaming_async(
contents=content,
generation_config=GenerationConfig(**optional_params),
generation_config=optional_params,
tools=tools,
)
optional_params["stream"] = True
optional_params["tools"] = tools
elif mode == "chat":
chat = llm_model.start_chat()
optional_params.pop(
@ -950,7 +1069,7 @@ async def async_streaming(
},
)
response = chat.send_message_streaming_async(prompt, **optional_params)
optional_params["stream"] = True
elif mode == "text":
optional_params.pop(
"stream", None
@ -1046,6 +1165,7 @@ def embedding(
encoding=None,
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
aembedding=False,
print_verbose=None,
):
@ -1066,7 +1186,17 @@ def embedding(
print_verbose(
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
)
creds, _ = google.auth.default(quota_project_id=vertex_project)
if vertex_credentials is not None and isinstance(vertex_credentials, str):
import google.oauth2.service_account
json_obj = json.loads(vertex_credentials)
creds = google.oauth2.service_account.Credentials.from_service_account_info(
json_obj,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
else:
creds, _ = google.auth.default(quota_project_id=vertex_project)
print_verbose(
f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
)

View file

@ -0,0 +1,469 @@
# What is this?
## Handler file for calling claude-3 on vertex ai
import os, types
import json
from enum import Enum
import requests, copy
import time, uuid
from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .prompt_templates.factory import (
contains_tag,
prompt_factory,
custom_prompt,
construct_tool_use_system_prompt,
extract_between_tags,
parse_xml_params,
)
import httpx
class VertexAIError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url=" https://cloud.google.com/vertex-ai/"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class VertexAIAnthropicConfig:
"""
Reference: https://docs.anthropic.com/claude/reference/messages_post
Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
- `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL.
- `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16".
The class `VertexAIAnthropicConfig` provides configuration for the VertexAI's Anthropic API interface. Below are the parameters:
- `max_tokens` Required (integer) max tokens,
- `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
- `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
- `temperature` Optional (float) The amount of randomness injected into the response
- `top_p` Optional (float) Use nucleus sampling.
- `top_k` Optional (int) Only sample from the top K options for each subsequent token
- `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
Note: Please make sure to modify the default parameters as required for your use case.
"""
max_tokens: Optional[int] = (
4096 # anthropic max - setting this doesn't impact response, but is required by anthropic.
)
system: Optional[str] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
stop_sequences: Optional[List[str]] = None
def __init__(
self,
max_tokens: Optional[int] = None,
anthropic_version: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key == "max_tokens" and value is None:
value = self.max_tokens
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"max_tokens",
"tools",
"tool_choice",
"stream",
"stop",
"temperature",
"top_p",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop_sequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
return optional_params
"""
- Run client init
- Support async completion, streaming
"""
# makes headers for API call
def refresh_auth(
credentials,
) -> str: # used when user passes in credentials as json string
from google.auth.transport.requests import Request # type: ignore[import-untyped]
if credentials.token is None:
credentials.refresh(Request())
if not credentials.token:
raise RuntimeError("Could not resolve API token from the credentials")
return credentials.token
def completion(
model: str,
messages: list,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
acompletion: bool = False,
client=None,
):
try:
import vertexai
from anthropic import AnthropicVertex
except:
raise VertexAIError(
status_code=400,
message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
)
if not (
hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models")
):
raise VertexAIError(
status_code=400,
message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
)
try:
## Load Config
config = litellm.VertexAIAnthropicConfig.get_config()
for k, v in config.items():
if k not in optional_params:
optional_params[k] = v
## Format Prompt
_is_function_call = False
messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params)
# Separate system prompt from rest of message
system_prompt_indices = []
system_prompt = ""
for idx, message in enumerate(messages):
if message["role"] == "system":
system_prompt += message["content"]
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
if len(system_prompt) > 0:
optional_params["system"] = system_prompt
# Format rest of message according to anthropic guidelines
try:
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic_xml"
)
except Exception as e:
raise VertexAIError(status_code=400, message=str(e))
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=optional_params["tools"]
)
optional_params["system"] = (
optional_params.get("system", "\n") + tool_calling_system_prompt
) # add the anthropic tool calling prompt to the system prompt
optional_params.pop("tools")
stream = optional_params.pop("stream", None)
data = {
"model": model,
"messages": messages,
**optional_params,
}
print_verbose(f"_is_function_call: {_is_function_call}")
## Completion Call
print_verbose(
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
)
access_token = None
if client is None:
if vertex_credentials is not None and isinstance(vertex_credentials, str):
import google.oauth2.service_account
json_obj = json.loads(vertex_credentials)
creds = (
google.oauth2.service_account.Credentials.from_service_account_info(
json_obj,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
)
### CHECK IF ACCESS
access_token = refresh_auth(credentials=creds)
vertex_ai_client = AnthropicVertex(
project_id=vertex_project,
region=vertex_location,
access_token=access_token,
)
else:
vertex_ai_client = client
if acompletion == True:
"""
- async streaming
- async completion
"""
if stream is not None and stream == True:
return async_streaming(
model=model,
messages=messages,
data=data,
print_verbose=print_verbose,
model_response=model_response,
logging_obj=logging_obj,
vertex_project=vertex_project,
vertex_location=vertex_location,
optional_params=optional_params,
client=client,
access_token=access_token,
)
else:
return async_completion(
model=model,
messages=messages,
data=data,
print_verbose=print_verbose,
model_response=model_response,
logging_obj=logging_obj,
vertex_project=vertex_project,
vertex_location=vertex_location,
optional_params=optional_params,
client=client,
access_token=access_token,
)
if stream is not None and stream == True:
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
},
)
response = vertex_ai_client.messages.create(**data, stream=True) # type: ignore
return response
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
},
)
message = vertex_ai_client.messages.create(**data) # type: ignore
text_content = message.content[0].text
## TOOL CALLING - OUTPUT PARSE
if text_content is not None and contains_tag("invoke", text_content):
function_name = extract_between_tags("tool_name", text_content)[0]
function_arguments_str = extract_between_tags("invoke", text_content)[
0
].strip()
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
function_arguments = parse_xml_params(function_arguments_str)
_message = litellm.Message(
tool_calls=[
{
"id": f"call_{uuid.uuid4()}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(function_arguments),
},
}
],
content=None,
)
model_response.choices[0].message = _message # type: ignore
else:
model_response.choices[0].message.content = text_content # type: ignore
model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)
## CALCULATING USAGE
prompt_tokens = message.usage.input_tokens
completion_tokens = message.usage.output_tokens
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
async def async_completion(
model: str,
messages: list,
data: dict,
model_response: ModelResponse,
print_verbose: Callable,
logging_obj,
vertex_project=None,
vertex_location=None,
optional_params=None,
client=None,
access_token=None,
):
from anthropic import AsyncAnthropicVertex
if client is None:
vertex_ai_client = AsyncAnthropicVertex(
project_id=vertex_project, region=vertex_location, access_token=access_token
)
else:
vertex_ai_client = client
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
},
)
message = await vertex_ai_client.messages.create(**data) # type: ignore
text_content = message.content[0].text
## TOOL CALLING - OUTPUT PARSE
if text_content is not None and contains_tag("invoke", text_content):
function_name = extract_between_tags("tool_name", text_content)[0]
function_arguments_str = extract_between_tags("invoke", text_content)[0].strip()
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
function_arguments = parse_xml_params(function_arguments_str)
_message = litellm.Message(
tool_calls=[
{
"id": f"call_{uuid.uuid4()}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(function_arguments),
},
}
],
content=None,
)
model_response.choices[0].message = _message # type: ignore
else:
model_response.choices[0].message.content = text_content # type: ignore
model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)
## CALCULATING USAGE
prompt_tokens = message.usage.input_tokens
completion_tokens = message.usage.output_tokens
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
async def async_streaming(
model: str,
messages: list,
data: dict,
model_response: ModelResponse,
print_verbose: Callable,
logging_obj,
vertex_project=None,
vertex_location=None,
optional_params=None,
client=None,
access_token=None,
):
from anthropic import AsyncAnthropicVertex
if client is None:
vertex_ai_client = AsyncAnthropicVertex(
project_id=vertex_project, region=vertex_location, access_token=access_token
)
else:
vertex_ai_client = client
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
},
)
response = await vertex_ai_client.messages.create(**data, stream=True) # type: ignore
logging_obj.post_call(input=messages, api_key=None, original_response=response)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="vertex_ai",
logging_obj=logging_obj,
)
return streamwrapper

View file

@ -104,7 +104,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response
@ -186,7 +186,7 @@ def batch_completions(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
final_outputs.append(model_response)
return final_outputs