forked from phoenix/litellm-mirror
fix(anthropic_text.py): add support for async text completion calls
This commit is contained in:
parent
bdf7f6d13c
commit
26286a54b8
6 changed files with 324 additions and 98 deletions
|
@ -8,6 +8,8 @@ from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
|
|||
import litellm
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
import httpx
|
||||
from .base import BaseLLM
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
|
||||
|
||||
class AnthropicConstants(Enum):
|
||||
|
@ -94,98 +96,13 @@ def validate_environment(api_key, user_headers):
|
|||
return headers
|
||||
|
||||
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
):
|
||||
headers = validate_environment(api_key, headers)
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="anthropic"
|
||||
)
|
||||
class AnthropicTextCompletion(BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
## Load Config
|
||||
config = litellm.AnthropicTextConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": api_base,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
## COMPLETION CALL
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
response = requests.post(
|
||||
api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=optional_params["stream"],
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
completion_stream = response.iter_lines()
|
||||
stream_response = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="anthropic",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return stream_response
|
||||
|
||||
else:
|
||||
response = requests.post(api_base, headers=headers, data=json.dumps(data))
|
||||
if response.status_code != 200:
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
def process_response(
|
||||
self, model_response: ModelResponse, response, encoding, prompt: str, model: str
|
||||
):
|
||||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = response.json()
|
||||
|
@ -221,9 +138,204 @@ def completion(
|
|||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
return model_response
|
||||
|
||||
async def async_completion(
|
||||
self,
|
||||
model: str,
|
||||
model_response: ModelResponse,
|
||||
api_base: str,
|
||||
logging_obj,
|
||||
encoding,
|
||||
headers: dict,
|
||||
data: dict,
|
||||
client=None,
|
||||
):
|
||||
if client is None:
|
||||
client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
|
||||
|
||||
def embedding():
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
pass
|
||||
response = await client.post(api_base, headers=headers, data=json.dumps(data))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=data["prompt"],
|
||||
api_key=headers.get("x-api-key"),
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
response = self.process_response(
|
||||
model_response=model_response,
|
||||
response=response,
|
||||
encoding=encoding,
|
||||
prompt=data["prompt"],
|
||||
model=model,
|
||||
)
|
||||
return response
|
||||
|
||||
async def async_streaming(
|
||||
self,
|
||||
model: str,
|
||||
api_base: str,
|
||||
logging_obj,
|
||||
headers: dict,
|
||||
data: Optional[dict],
|
||||
client=None,
|
||||
):
|
||||
if client is None:
|
||||
client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
|
||||
|
||||
response = await client.post(api_base, headers=headers, data=json.dumps(data))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
completion_stream = response.aiter_lines()
|
||||
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="anthropic_text",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streamwrapper
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
acompletion: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client=None,
|
||||
):
|
||||
headers = validate_environment(api_key, headers)
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="anthropic"
|
||||
)
|
||||
|
||||
## Load Config
|
||||
config = litellm.AnthropicTextConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": api_base,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
## COMPLETION CALL
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if acompletion == True:
|
||||
return self.async_streaming(
|
||||
model=model,
|
||||
api_base=api_base,
|
||||
logging_obj=logging_obj,
|
||||
headers=headers,
|
||||
data=data,
|
||||
client=None,
|
||||
)
|
||||
|
||||
if client is None:
|
||||
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
|
||||
|
||||
response = client.post(
|
||||
api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
# stream=optional_params["stream"],
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
completion_stream = response.iter_lines()
|
||||
stream_response = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="anthropic_text",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return stream_response
|
||||
elif acompletion == True:
|
||||
return self.async_completion(
|
||||
model=model,
|
||||
model_response=model_response,
|
||||
api_base=api_base,
|
||||
logging_obj=logging_obj,
|
||||
encoding=encoding,
|
||||
headers=headers,
|
||||
data=data,
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
if client is None:
|
||||
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
|
||||
response = client.post(api_base, headers=headers, data=json.dumps(data))
|
||||
if response.status_code != 200:
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
|
||||
response = self.process_response(
|
||||
model_response=model_response,
|
||||
response=response,
|
||||
encoding=encoding,
|
||||
prompt=data["prompt"],
|
||||
model=model,
|
||||
)
|
||||
return response
|
||||
|
||||
def embedding(self):
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
pass
|
||||
|
|
|
@ -58,13 +58,16 @@ class AsyncHTTPHandler:
|
|||
|
||||
|
||||
class HTTPHandler:
|
||||
def __init__(self, concurrent_limit=1000):
|
||||
def __init__(
|
||||
self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
|
||||
):
|
||||
# Create a client with a connection pool
|
||||
self.client = httpx.Client(
|
||||
timeout=timeout,
|
||||
limits=httpx.Limits(
|
||||
max_connections=concurrent_limit,
|
||||
max_keepalive_connections=concurrent_limit,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def close(self):
|
||||
|
|
|
@ -67,6 +67,7 @@ from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
|
|||
from .llms.azure import AzureChatCompletion
|
||||
from .llms.azure_text import AzureTextCompletion
|
||||
from .llms.anthropic import AnthropicChatCompletion
|
||||
from .llms.anthropic_text import AnthropicTextCompletion
|
||||
from .llms.huggingface_restapi import Huggingface
|
||||
from .llms.prompt_templates.factory import (
|
||||
prompt_factory,
|
||||
|
@ -99,6 +100,7 @@ dotenv.load_dotenv() # Loading env variables using dotenv
|
|||
openai_chat_completions = OpenAIChatCompletion()
|
||||
openai_text_completions = OpenAITextCompletion()
|
||||
anthropic_chat_completions = AnthropicChatCompletion()
|
||||
anthropic_text_completions = AnthropicTextCompletion()
|
||||
azure_chat_completions = AzureChatCompletion()
|
||||
azure_text_completions = AzureTextCompletion()
|
||||
huggingface = Huggingface()
|
||||
|
@ -1165,10 +1167,11 @@ def completion(
|
|||
or get_secret("ANTHROPIC_API_BASE")
|
||||
or "https://api.anthropic.com/v1/complete"
|
||||
)
|
||||
response = anthropic_text.completion(
|
||||
response = anthropic_text_completions.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=api_base,
|
||||
acompletion=acompletion,
|
||||
custom_prompt_dict=litellm.custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
|
|
|
@ -492,6 +492,31 @@ def test_completion_claude2_1():
|
|||
|
||||
# test_completion_claude2_1()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acompletion_claude2_1():
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
print("claude2.1 test request")
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Your goal is generate a joke on the topic user gives.",
|
||||
},
|
||||
{"role": "user", "content": "Generate a 3 liner joke for me"},
|
||||
]
|
||||
# test without max tokens
|
||||
response = await litellm.acompletion(model="claude-2.1", messages=messages)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
print(response.usage)
|
||||
print(response.usage.completion_tokens)
|
||||
print(response["usage"]["completion_tokens"])
|
||||
# print("new cost tracking")
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# def test_completion_oobabooga():
|
||||
# try:
|
||||
# response = completion(
|
||||
|
|
|
@ -380,6 +380,51 @@ def test_completion_claude_stream():
|
|||
|
||||
|
||||
# test_completion_claude_stream()
|
||||
def test_completion_claude_2_stream():
|
||||
litellm.set_verbose = True
|
||||
response = completion(
|
||||
model="claude-2",
|
||||
messages=[{"role": "user", "content": "hello from litellm"}],
|
||||
stream=True,
|
||||
)
|
||||
complete_response = ""
|
||||
# Add any assertions here to check the response
|
||||
idx = 0
|
||||
for chunk in response:
|
||||
print(chunk)
|
||||
# print(chunk.choices[0].delta)
|
||||
chunk, finished = streaming_format_tests(idx, chunk)
|
||||
if finished:
|
||||
break
|
||||
complete_response += chunk
|
||||
idx += 1
|
||||
if complete_response.strip() == "":
|
||||
raise Exception("Empty response received")
|
||||
print(f"completion_response: {complete_response}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acompletion_claude_2_stream():
|
||||
litellm.set_verbose = True
|
||||
response = await litellm.acompletion(
|
||||
model="claude-2",
|
||||
messages=[{"role": "user", "content": "hello from litellm"}],
|
||||
stream=True,
|
||||
)
|
||||
complete_response = ""
|
||||
# Add any assertions here to check the response
|
||||
idx = 0
|
||||
async for chunk in response:
|
||||
print(chunk)
|
||||
# print(chunk.choices[0].delta)
|
||||
chunk, finished = streaming_format_tests(idx, chunk)
|
||||
if finished:
|
||||
break
|
||||
complete_response += chunk
|
||||
idx += 1
|
||||
if complete_response.strip() == "":
|
||||
raise Exception("Empty response received")
|
||||
print(f"completion_response: {complete_response}")
|
||||
|
||||
|
||||
def test_completion_palm_stream():
|
||||
|
|
|
@ -8810,6 +8810,35 @@ class CustomStreamWrapper:
|
|||
self.holding_chunk = ""
|
||||
return hold, curr_chunk
|
||||
|
||||
def handle_anthropic_text_chunk(self, chunk):
|
||||
str_line = chunk
|
||||
if isinstance(chunk, bytes): # Handle binary data
|
||||
str_line = chunk.decode("utf-8") # Convert bytes to string
|
||||
text = ""
|
||||
is_finished = False
|
||||
finish_reason = None
|
||||
if str_line.startswith("data:"):
|
||||
data_json = json.loads(str_line[5:])
|
||||
type_chunk = data_json.get("type", None)
|
||||
if type_chunk == "completion":
|
||||
text = data_json.get("completion")
|
||||
finish_reason = data_json.get("stop_reason")
|
||||
if finish_reason is not None:
|
||||
is_finished = True
|
||||
return {
|
||||
"text": text,
|
||||
"is_finished": is_finished,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
elif "error" in str_line:
|
||||
raise ValueError(f"Unable to parse response. Original response: {str_line}")
|
||||
else:
|
||||
return {
|
||||
"text": text,
|
||||
"is_finished": is_finished,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
|
||||
def handle_anthropic_chunk(self, chunk):
|
||||
str_line = chunk
|
||||
if isinstance(chunk, bytes): # Handle binary data
|
||||
|
@ -9497,6 +9526,14 @@ class CustomStreamWrapper:
|
|||
completion_obj["content"] = response_obj["text"]
|
||||
if response_obj["is_finished"]:
|
||||
self.received_finish_reason = response_obj["finish_reason"]
|
||||
elif (
|
||||
self.custom_llm_provider
|
||||
and self.custom_llm_provider == "anthropic_text"
|
||||
):
|
||||
response_obj = self.handle_anthropic_text_chunk(chunk)
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
if response_obj["is_finished"]:
|
||||
self.received_finish_reason = response_obj["finish_reason"]
|
||||
elif self.model == "replicate" or self.custom_llm_provider == "replicate":
|
||||
response_obj = self.handle_replicate_chunk(chunk)
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
|
@ -10074,6 +10111,7 @@ class CustomStreamWrapper:
|
|||
or self.custom_llm_provider == "text-completion-openai"
|
||||
or self.custom_llm_provider == "azure_text"
|
||||
or self.custom_llm_provider == "anthropic"
|
||||
or self.custom_llm_provider == "anthropic_text"
|
||||
or self.custom_llm_provider == "huggingface"
|
||||
or self.custom_llm_provider == "ollama"
|
||||
or self.custom_llm_provider == "ollama_chat"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue