(Refactor) Code Quality improvement - Use Common base handler for anthropic_text/ (#7143)

* add anthropic text provider

* add ANTHROPIC_TEXT to LlmProviders

* fix anthropic text implementation

* working anthropic text claude-2

* test_acompletion_claude2_stream

* add param mapping for anthropic text

* fix unused imports

* fix anthropic completion handler.py
Ishaan Jaff 2024-12-10 12:23:58 -08:00 committed by GitHub
parent 5e016fe66a
commit bdb20821ea
12 changed files with 528 additions and 498 deletions
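
For reference, a minimal usage sketch of the route this commit introduces (an illustration, not part of the commit; it assumes ANTHROPIC_API_KEY is set and a litellm build that includes this change):

import litellm

# legacy text-completion Claude models now go through the common base handler;
# both "claude-2" and the explicit "anthropic_text/claude-2" slug resolve to /v1/complete
response = litellm.completion(
    model="anthropic_text/claude-2",
    messages=[{"role": "user", "content": "Generate a 3 liner joke for me"}],
    max_tokens=64,
)
print(response.choices[0].message.content)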

View file

@@ -846,6 +846,7 @@ class LlmProviders(str, Enum):
COHERE_CHAT = "cohere_chat"
CLARIFAI = "clarifai"
ANTHROPIC = "anthropic"
+ANTHROPIC_TEXT = "anthropic_text"
REPLICATE = "replicate"
HUGGINGFACE = "huggingface"
TOGETHER_AI = "together_ai"
@@ -1060,7 +1061,7 @@ from .llms.anthropic.experimental_pass_through.transformation import (
AnthropicExperimentalPassThroughConfig,
)
from .llms.groq.stt.transformation import GroqSTTConfig
-from .llms.anthropic.completion import AnthropicTextConfig
+from .llms.anthropic.completion.transformation import AnthropicTextConfig
from .llms.databricks.chat.transformation import DatabricksConfig
from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
from .llms.predibase import PredibaseConfig

View file

@@ -12,6 +12,7 @@ LITELLM_CHAT_PROVIDERS = [
"cohere_chat",
"clarifai",
"anthropic",
+"anthropic_text",
"replicate",
"huggingface",
"together_ai",

View file

@@ -52,6 +52,39 @@ def handle_cohere_chat_model_custom_llm_provider(
return model, custom_llm_provider

+def handle_anthropic_text_model_custom_llm_provider(
+model: str, custom_llm_provider: Optional[str] = None
+) -> Tuple[str, Optional[str]]:
+"""
+if user sets model = "anthropic/claude-2" -> use custom_llm_provider = "anthropic_text"
+Args:
+model:
+custom_llm_provider:
+Returns:
+model, custom_llm_provider
+"""
+if custom_llm_provider:
+if (
+custom_llm_provider == "anthropic"
+and litellm.AnthropicTextConfig._is_anthropic_text_model(model)
+):
+return model, "anthropic_text"
+if "/" in model:
+_custom_llm_provider, _model = model.split("/", 1)
+if (
+_custom_llm_provider
+and _custom_llm_provider == "anthropic"
+and litellm.AnthropicTextConfig._is_anthropic_text_model(_model)
+):
+return _model, "anthropic_text"
+return model, custom_llm_provider

def get_llm_provider( # noqa: PLR0915
model: str,
custom_llm_provider: Optional[str] = None,
@@ -92,6 +125,10 @@ def get_llm_provider( # noqa: PLR0915
model, custom_llm_provider
)
+model, custom_llm_provider = handle_anthropic_text_model_custom_llm_provider(
+model, custom_llm_provider
+)
if custom_llm_provider:
if (
model.split("/")[0] == custom_llm_provider
@@ -210,7 +247,10 @@ def get_llm_provider( # noqa: PLR0915
custom_llm_provider = "text-completion-openai"
## anthropic
elif model in litellm.anthropic_models:
-custom_llm_provider = "anthropic"
+if litellm.AnthropicTextConfig._is_anthropic_text_model(model):
+custom_llm_provider = "anthropic_text"
+else:
+custom_llm_provider = "anthropic"
## cohere
elif model in litellm.cohere_models or model in litellm.cohere_embedding_models:
custom_llm_provider = "cohere"
@@ -531,7 +571,9 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
)
elif custom_llm_provider == "galadriel":
api_base = (
-api_base or get_secret("GALADRIEL_API_BASE") or "https://api.galadriel.com/v1"
+api_base
+or get_secret("GALADRIEL_API_BASE")
+or "https://api.galadriel.com/v1"
) # type: ignore
dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
if api_base is not None and not isinstance(api_base, str):
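
A sketch of the routing behaviour the new helper introduces (assumes `get_llm_provider` is importable from the top-level `litellm` module, as in recent releases, and returns a `(model, custom_llm_provider, dynamic_api_key, api_base)` tuple):

import litellm

# claude-2 / claude-instant-1 are re-routed to the dedicated "anthropic_text" provider
model, provider, _, _ = litellm.get_llm_provider(model="anthropic/claude-2")
assert (model, provider) == ("claude-2", "anthropic_text")

# newer Anthropic models keep the default /v1/messages route
model, provider, _, _ = litellm.get_llm_provider(model="claude-3-5-sonnet-20240620")
assert provider == "anthropic"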

View file

@@ -223,40 +223,6 @@ class CustomStreamWrapper:
self.holding_chunk = ""
return hold, curr_chunk

-def handle_anthropic_text_chunk(self, chunk):
-"""
-For old anthropic models - claude-1, claude-2.
-Claude-3 is handled from within Anthropic.py VIA ModelResponseIterator()
-"""
-str_line = chunk
-if isinstance(chunk, bytes): # Handle binary data
-str_line = chunk.decode("utf-8") # Convert bytes to string
-text = ""
-is_finished = False
-finish_reason = None
-if str_line.startswith("data:"):
-data_json = json.loads(str_line[5:])
-type_chunk = data_json.get("type", None)
-if type_chunk == "completion":
-text = data_json.get("completion")
-finish_reason = data_json.get("stop_reason")
-if finish_reason is not None:
-is_finished = True
-return {
-"text": text,
-"is_finished": is_finished,
-"finish_reason": finish_reason,
-}
-elif "error" in str_line:
-raise ValueError(f"Unable to parse response. Original response: {str_line}")
-else:
-return {
-"text": text,
-"is_finished": is_finished,
-"finish_reason": finish_reason,
-}

def handle_predibase_chunk(self, chunk):
try:
if not isinstance(chunk, str):
@@ -1005,14 +971,6 @@ class CustomStreamWrapper:
setattr(model_response, key, value)
response_obj = anthropic_response_obj
-elif (
-self.custom_llm_provider
-and self.custom_llm_provider == "anthropic_text"
-):
-response_obj = self.handle_anthropic_text_chunk(chunk)
-completion_obj["content"] = response_obj["text"]
-if response_obj["is_finished"]:
-self.received_finish_reason = response_obj["finish_reason"]
elif self.model == "replicate" or self.custom_llm_provider == "replicate":
response_obj = self.handle_replicate_chunk(chunk)
completion_obj["content"] = response_obj["text"]

View file

@@ -1,362 +0,0 @@
"""
Translation logic for anthropic's `/v1/complete` endpoint
"""
import json
import os
import time
import types
from enum import Enum
from typing import Callable, Optional
import httpx
import requests
import litellm
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from ..base import BaseLLM
from ..prompt_templates.factory import custom_prompt, prompt_factory
class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
class AnthropicError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url="https://api.anthropic.com/v1/complete"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class AnthropicTextConfig:
"""
Reference: https://docs.anthropic.com/claude/reference/complete_post
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
"""
max_tokens_to_sample: Optional[int] = (
litellm.max_tokens
) # anthropic requires a default
stop_sequences: Optional[list] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
top_k: Optional[int] = None
metadata: Optional[dict] = None
def __init__(
self,
max_tokens_to_sample: Optional[int] = 256, # anthropic requires a default
stop_sequences: Optional[list] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
top_k: Optional[int] = None,
metadata: Optional[dict] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
# makes headers for API call
def validate_environment(api_key, user_headers):
if api_key is None:
raise ValueError(
"Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params"
)
headers = {
"accept": "application/json",
"anthropic-version": "2023-06-01",
"content-type": "application/json",
"x-api-key": api_key,
}
if user_headers is not None and isinstance(user_headers, dict):
headers = {**headers, **user_headers}
return headers
class AnthropicTextCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def _process_response(
self, model_response: ModelResponse, response, encoding, prompt: str, model: str
):
## RESPONSE OBJECT
try:
completion_response = response.json()
except Exception:
raise AnthropicError(
message=response.text, status_code=response.status_code
)
if "error" in completion_response:
raise AnthropicError(
message=str(completion_response["error"]),
status_code=response.status_code,
)
else:
if len(completion_response["completion"]) > 0:
model_response.choices[0].message.content = completion_response[ # type: ignore
"completion"
]
model_response.choices[0].finish_reason = completion_response["stop_reason"]
## CALCULATING USAGE
prompt_tokens = len(
encoding.encode(prompt)
) ##[TODO] use the anthropic tokenizer here
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) ##[TODO] use the anthropic tokenizer here
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
async def async_completion(
self,
model: str,
model_response: ModelResponse,
api_base: str,
logging_obj,
encoding,
headers: dict,
data: dict,
client=None,
):
if client is None:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.ANTHROPIC,
params={"timeout": httpx.Timeout(timeout=600.0, connect=5.0)},
)
response = await client.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
input=data["prompt"],
api_key=headers.get("x-api-key"),
original_response=response.text,
additional_args={"complete_input_dict": data},
)
response = self._process_response(
model_response=model_response,
response=response,
encoding=encoding,
prompt=data["prompt"],
model=model,
)
return response
async def async_streaming(
self,
model: str,
api_base: str,
logging_obj,
headers: dict,
data: Optional[dict],
client=None,
):
if client is None:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.ANTHROPIC,
params={"timeout": httpx.Timeout(timeout=600.0, connect=5.0)},
)
response = await client.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
completion_stream = response.aiter_lines()
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic_text",
logging_obj=logging_obj,
)
return streamwrapper
def completion(
self,
model: str,
messages: list,
api_base: str,
acompletion: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client=None,
):
headers = validate_environment(api_key, headers)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
## Load Config
config = litellm.AnthropicTextConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
"model": model,
"prompt": prompt,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] is True:
if acompletion is True:
return self.async_streaming(
model=model,
api_base=api_base,
logging_obj=logging_obj,
headers=headers,
data=data,
client=None,
)
if client is None:
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = client.post(
api_base,
headers=headers,
data=json.dumps(data),
# stream=optional_params["stream"],
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
completion_stream = response.iter_lines()
stream_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic_text",
logging_obj=logging_obj,
)
return stream_response
elif acompletion is True:
return self.async_completion(
model=model,
model_response=model_response,
api_base=api_base,
logging_obj=logging_obj,
encoding=encoding,
headers=headers,
data=data,
client=client,
)
else:
if client is None:
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = client.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
response = self._process_response(
model_response=model_response,
response=response,
encoding=encoding,
prompt=data["prompt"],
model=model,
)
return response
def embedding(self):
# logic for parsing in - calling - parsing out model embedding calls
pass

View file

@@ -0,0 +1,5 @@
"""
Anthropic /complete API - uses `llm_http_handler.py` to make httpx requests
Request/Response transformation is handled in `transformation.py`
"""

View file

@@ -0,0 +1,307 @@
"""
Translation logic for anthropic's `/v1/complete` endpoint
Litellm provider slug: `anthropic_text/<model_name>`
"""
import json
import time
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
import httpx
import litellm
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.transformation import (
BaseConfig,
BaseLLMException,
LiteLLMLoggingObj,
)
from litellm.llms.prompt_templates.factory import custom_prompt, prompt_factory
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
ChatCompletionToolCallChunk,
ChatCompletionUsageBlock,
GenericStreamingChunk,
ModelResponse,
Usage,
)
class AnthropicTextError(BaseLLMException):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url="https://api.anthropic.com/v1/complete"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
message=self.message,
status_code=self.status_code,
request=self.request,
response=self.response,
) # Call the base class constructor with the parameters it needs
class AnthropicTextConfig(BaseConfig):
"""
Reference: https://docs.anthropic.com/claude/reference/complete_post
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
"""
max_tokens_to_sample: Optional[int] = (
litellm.max_tokens
) # anthropic requires a default
stop_sequences: Optional[list] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
top_k: Optional[int] = None
metadata: Optional[dict] = None
def __init__(
self,
max_tokens_to_sample: Optional[int] = 256, # anthropic requires a default
stop_sequences: Optional[list] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
top_k: Optional[int] = None,
metadata: Optional[dict] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
# makes headers for API call
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
if api_key is None:
raise ValueError(
"Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params"
)
_headers = {
"accept": "application/json",
"anthropic-version": "2023-06-01",
"content-type": "application/json",
"x-api-key": api_key,
}
headers.update(_headers)
return headers
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
prompt = self._get_anthropic_text_prompt_from_messages(
messages=messages, model=model
)
## Load Config
config = litellm.AnthropicTextConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
"model": model,
"prompt": prompt,
**optional_params,
}
return data
def get_supported_openai_params(self, model: str):
"""
Anthropic /complete API Ref: https://docs.anthropic.com/en/api/complete
"""
return [
"stream",
"max_tokens",
"max_completion_tokens",
"stop",
"temperature",
"top_p",
"extra_headers",
"user",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
"""
Follows the same logic as the AnthropicConfig.map_openai_params method (which is the Anthropic /messages API)
Note: the only difference is in the get supported openai params method between the AnthropicConfig and AnthropicTextConfig
API Ref: https://docs.anthropic.com/en/api/complete
"""
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens_to_sample"] = value
if param == "max_completion_tokens":
optional_params["max_tokens_to_sample"] = value
if param == "stream" and value is True:
optional_params["stream"] = value
if param == "stop" and (isinstance(value, str) or isinstance(value, list)):
_value = litellm.AnthropicConfig()._map_stop_sequences(value)
if _value is not None:
optional_params["stop_sequences"] = _value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "user":
optional_params["metadata"] = {"user_id": value}
return optional_params
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
try:
completion_response = raw_response.json()
except Exception:
raise AnthropicTextError(
message=raw_response.text, status_code=raw_response.status_code
)
prompt = self._get_anthropic_text_prompt_from_messages(
messages=messages, model=model
)
if "error" in completion_response:
raise AnthropicTextError(
message=str(completion_response["error"]),
status_code=raw_response.status_code,
)
else:
if len(completion_response["completion"]) > 0:
model_response.choices[0].message.content = completion_response[ # type: ignore
"completion"
]
model_response.choices[0].finish_reason = completion_response["stop_reason"]
## CALCULATING USAGE
prompt_tokens = len(
encoding.encode(prompt)
) ##[TODO] use the anthropic tokenizer here
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) ##[TODO] use the anthropic tokenizer here
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def get_error_class(
self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
) -> BaseLLMException:
return AnthropicTextError(
status_code=status_code,
message=error_message,
)
@staticmethod
def _is_anthropic_text_model(model: str) -> bool:
return model == "claude-2" or model == "claude-instant-1"
def _get_anthropic_text_prompt_from_messages(
self, messages: List[AllMessageValues], model: str
) -> str:
custom_prompt_dict = litellm.custom_prompt_dict
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
return str(prompt)
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
"Not required"
raise NotImplementedError
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
):
return AnthropicTextCompletionResponseIterator(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)
class AnthropicTextCompletionResponseIterator(BaseModelResponseIterator):
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
provider_specific_fields = None
index = int(chunk.get("index", 0))
_chunk_text = chunk.get("completion", None)
if _chunk_text is not None and isinstance(_chunk_text, str):
text = _chunk_text
finish_reason = chunk.get("stop_reason", None)
if finish_reason is not None:
is_finished = True
returned_chunk = GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=index,
provider_specific_fields=provider_specific_fields,
)
return returned_chunk
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")

View file

@@ -2839,7 +2839,7 @@ def prompt_factory(
if custom_llm_provider == "ollama":
return ollama_pt(model=model, messages=messages)
elif custom_llm_provider == "anthropic":
-if model == "claude-instant-1" or model == "claude-2":
+if litellm.AnthropicTextConfig._is_anthropic_text_model(model):
return anthropic_pt(messages=messages)
return anthropic_messages_pt(
messages=messages, model=model, llm_provider=custom_llm_provider
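
For reference, `anthropic_pt` renders chat messages into the legacy Human/Assistant prompt string that /v1/complete expects, which is why the old text models need their own branch here. A rough sketch of the expected shape (the helper's exact output may differ slightly):

messages = [{"role": "user", "content": "Tell me a joke."}]

# approximate prompt handed to /v1/complete for claude-2 / claude-instant-1
expected_prompt = "\n\nHuman: Tell me a joke.\n\nAssistant: "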

View file

@@ -99,7 +99,6 @@ from .llms import (
)
from .llms.ai21 import completion as ai21
from .llms.anthropic.chat import AnthropicChatCompletion
-from .llms.anthropic.completion import AnthropicTextCompletion
from .llms.azure.audio_transcriptions import AzureAudioTranscription
from .llms.azure.azure import AzureChatCompletion, _check_dynamic_azure_params
from .llms.azure.chat.o1_handler import AzureOpenAIO1ChatCompletion
@@ -204,7 +203,6 @@ together_ai_text_completions = TogetherAITextCompletion()
azure_ai_chat_completions = AzureAIChatCompletion()
azure_ai_embedding = AzureAIEmbedding()
anthropic_chat_completions = AnthropicChatCompletion()
-anthropic_text_completions = AnthropicTextCompletion()
azure_chat_completions = AzureChatCompletion()
azure_o1_chat_completions = AzureOpenAIO1ChatCompletion()
azure_text_completions = AzureTextCompletion()
@@ -464,6 +462,7 @@ async def acompletion(
or custom_llm_provider == "sagemaker"
or custom_llm_provider == "sagemaker_chat"
or custom_llm_provider == "anthropic"
+or custom_llm_provider == "anthropic_text"
or custom_llm_provider == "predibase"
or custom_llm_provider == "bedrock"
or custom_llm_provider == "databricks"
@@ -1705,6 +1704,41 @@ def completion( # type: ignore # noqa: PLR0915
api_key=clarifai_key,
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
)
+elif custom_llm_provider == "anthropic_text":
+api_key = (
+api_key
+or litellm.anthropic_key
+or litellm.api_key
+or os.environ.get("ANTHROPIC_API_KEY")
+)
+custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
+api_base = (
+api_base
+or litellm.api_base
+or get_secret("ANTHROPIC_API_BASE")
+or get_secret("ANTHROPIC_BASE_URL")
+or "https://api.anthropic.com/v1/complete"
+)
+if api_base is not None and not api_base.endswith("/v1/complete"):
+api_base += "/v1/complete"
+response = base_llm_http_handler.completion(
+model=model,
+stream=stream,
+messages=messages,
+acompletion=acompletion,
+api_base=api_base,
+model_response=model_response,
+optional_params=optional_params,
+litellm_params=litellm_params,
+custom_llm_provider="anthropic_text",
+timeout=timeout,
+headers=headers,
+encoding=encoding,
+api_key=api_key,
+logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
+)
elif custom_llm_provider == "anthropic":
api_key = (
api_key
@@ -1713,69 +1747,38 @@ def completion( # type: ignore # noqa: PLR0915
or os.environ.get("ANTHROPIC_API_KEY")
)
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-if (model == "claude-2") or (model == "claude-instant-1"):
-# call anthropic /completion, only use this route for claude-2, claude-instant-1
-api_base = (
-api_base
-or litellm.api_base
-or get_secret("ANTHROPIC_API_BASE")
-or get_secret("ANTHROPIC_BASE_URL")
-or "https://api.anthropic.com/v1/complete"
-)
-if api_base is not None and not api_base.endswith("/v1/complete"):
-api_base += "/v1/complete"
-response = anthropic_text_completions.completion(
-model=model,
-messages=messages,
-api_base=api_base,
-acompletion=acompletion,
-custom_prompt_dict=litellm.custom_prompt_dict,
-model_response=model_response,
-print_verbose=print_verbose,
-optional_params=optional_params,
-litellm_params=litellm_params,
-logger_fn=logger_fn,
-encoding=encoding, # for calculating input/output tokens
-api_key=api_key,
-logging_obj=logging,
-headers=headers,
-)
-else:
-# call /messages
-# default route for all anthropic models
-api_base = (
-api_base
-or litellm.api_base
-or get_secret("ANTHROPIC_API_BASE")
-or get_secret("ANTHROPIC_BASE_URL")
-or "https://api.anthropic.com/v1/messages"
-)
-if api_base is not None and not api_base.endswith("/v1/messages"):
-api_base += "/v1/messages"
-response = anthropic_chat_completions.completion(
-model=model,
-messages=messages,
-api_base=api_base,
-acompletion=acompletion,
-custom_prompt_dict=litellm.custom_prompt_dict,
-model_response=model_response,
-print_verbose=print_verbose,
-optional_params=optional_params,
-litellm_params=litellm_params,
-logger_fn=logger_fn,
-encoding=encoding, # for calculating input/output tokens
-api_key=api_key,
-logging_obj=logging,
-headers=headers,
-timeout=timeout,
-client=client,
-custom_llm_provider=custom_llm_provider,
-)
+# call /messages
+# default route for all anthropic models
+api_base = (
+api_base
+or litellm.api_base
+or get_secret("ANTHROPIC_API_BASE")
+or get_secret("ANTHROPIC_BASE_URL")
+or "https://api.anthropic.com/v1/messages"
+)
+if api_base is not None and not api_base.endswith("/v1/messages"):
+api_base += "/v1/messages"
+response = anthropic_chat_completions.completion(
+model=model,
+messages=messages,
+api_base=api_base,
+acompletion=acompletion,
+custom_prompt_dict=litellm.custom_prompt_dict,
+model_response=model_response,
+print_verbose=print_verbose,
+optional_params=optional_params,
+litellm_params=litellm_params,
+logger_fn=logger_fn,
+encoding=encoding, # for calculating input/output tokens
+api_key=api_key,
+logging_obj=logging,
+headers=headers,
+timeout=timeout,
+client=client,
+custom_llm_provider=custom_llm_provider,
+)
if optional_params.get("stream", False) or acompletion is True:
## LOGGING
logging.post_call(
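
One practical effect of the new `anthropic_text` branch above: a custom ANTHROPIC_API_BASE (for example a proxy) gets "/v1/complete" appended automatically when it does not already end with it. A hedged sketch, using a hypothetical proxy URL:

import os
import litellm

os.environ["ANTHROPIC_API_BASE"] = "https://anthropic-proxy.internal.example.com"  # hypothetical

# litellm should call https://anthropic-proxy.internal.example.com/v1/complete here
response = litellm.completion(
    model="anthropic_text/claude-instant-1",
    messages=[{"role": "user", "content": "ping"}],
    max_tokens=8,
)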

View file

@@ -2830,6 +2830,32 @@ def get_optional_params( # noqa: PLR0915
else False
),
)
+elif custom_llm_provider == "anthropic_text":
+supported_params = get_supported_openai_params(
+model=model, custom_llm_provider=custom_llm_provider
+)
+optional_params = litellm.AnthropicTextConfig().map_openai_params(
+model=model,
+non_default_params=non_default_params,
+optional_params=optional_params,
+drop_params=(
+drop_params
+if drop_params is not None and isinstance(drop_params, bool)
+else False
+),
+)
+_check_valid_arg(supported_params=supported_params)
+optional_params = litellm.AnthropicTextConfig().map_openai_params(
+model=model,
+non_default_params=non_default_params,
+optional_params=optional_params,
+drop_params=(
+drop_params
+if drop_params is not None and isinstance(drop_params, bool)
+else False
+),
+)
elif custom_llm_provider == "cohere":
## check if unsupported param passed in
supported_params = get_supported_openai_params(
@@ -4208,7 +4234,7 @@ def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
if llm_provider == "openai" or llm_provider == "text-completion-openai":
api_key = api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
# anthropic
-elif llm_provider == "anthropic":
+elif llm_provider == "anthropic" or llm_provider == "anthropic_text":
api_key = api_key or litellm.anthropic_key or get_secret("ANTHROPIC_API_KEY")
# ai21
elif llm_provider == "ai21":
@@ -6251,6 +6277,8 @@ class ProviderConfigManager:
return litellm.ClarifaiConfig()
elif litellm.LlmProviders.ANTHROPIC == provider:
return litellm.AnthropicConfig()
+elif litellm.LlmProviders.ANTHROPIC_TEXT == provider:
+return litellm.AnthropicTextConfig()
elif litellm.LlmProviders.VERTEX_AI == provider:
if "claude" in model:
return litellm.VertexAIAnthropicConfig()

View file

@@ -0,0 +1,73 @@
import asyncio
import os
from re import T
import sys
import traceback
from dotenv import load_dotenv
import litellm.types
import litellm.types.utils
from litellm.llms.anthropic.chat import ModelResponseIterator
load_dotenv()
import io
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from typing import Optional
from unittest.mock import MagicMock, patch
import pytest
@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["claude-2", "anthropic/claude-2"])
async def test_acompletion_claude2(model):
try:
litellm.set_verbose = True
messages = [
{
"role": "system",
"content": "Your goal is generate a joke on the topic user gives.",
},
{"role": "user", "content": "Generate a 3 liner joke for me"},
]
# test without max-tokens
response = await litellm.acompletion(model=model, messages=messages)
# Add any assertions here to check the response
print(response)
print(response.usage)
print(response.usage.completion_tokens)
print(response["usage"]["completion_tokens"])
# print("new cost tracking")
except litellm.InternalServerError:
pytest.skip("model is overloaded.")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_acompletion_claude2_stream():
try:
litellm.set_verbose = False
messages = [
{
"role": "system",
"content": "Your goal is generate a joke on the topic user gives.",
},
{"role": "user", "content": "Generate a 3 liner joke for me"},
]
# test without max-tokens
response = await litellm.acompletion(
model="anthropic_text/claude-2",
messages=messages,
stream=True,
max_tokens=10,
)
async for chunk in response:
print(chunk)
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@@ -1126,32 +1126,6 @@ def test_completion_mistral_api_modified_input():
pytest.fail(f"Error occurred: {e}")

-@pytest.mark.asyncio
-async def test_acompletion_claude2_1():
-try:
-litellm.set_verbose = True
-print("claude2.1 test request")
-messages = [
-{
-"role": "system",
-"content": "Your goal is generate a joke on the topic user gives.",
-},
-{"role": "user", "content": "Generate a 3 liner joke for me"},
-]
-# test without max-tokens
-response = await litellm.acompletion(model="claude-2.1", messages=messages)
-# Add any assertions here to check the response
-print(response)
-print(response.usage)
-print(response.usage.completion_tokens)
-print(response["usage"]["completion_tokens"])
-# print("new cost tracking")
-except litellm.InternalServerError:
-pytest.skip("model is overloaded.")
-except Exception as e:
-pytest.fail(f"Error occurred: {e}")

# def test_completion_oobabooga():
# try:
# response = completion(