(Refactor) Code Quality improvement - Use Common base handler for anthropic_text/ (#7143)

* add anthropic text provider
* add ANTHROPIC_TEXT to LlmProviders
* fix anthropic text implementation
* working anthropic text claude-2
* test_acompletion_claude2_stream
* add param mapping for anthropic text
* fix unused imports
* fix anthropic completion handler.py

Parent: 5e016fe66a, commit: bdb20821ea
12 changed files with 528 additions and 498 deletions
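For context, a minimal usage sketch of the route this refactor touches (assumes ANTHROPIC_API_KEY is set; the model name and the `anthropic_text/` prefix come from the tests added in this commit):

import litellm

# "claude-2" / "claude-instant-1" are the legacy /v1/complete models; they can be
# addressed directly or via the explicit anthropic_text/ provider prefix.
response = litellm.completion(
    model="anthropic_text/claude-2",
    messages=[{"role": "user", "content": "Generate a 3 liner joke for me"}],
    max_tokens=100,  # mapped to max_tokens_to_sample by AnthropicTextConfig
)
print(response.choices[0].message.content)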
@@ -846,6 +846,7 @@ class LlmProviders(str, Enum):
    COHERE_CHAT = "cohere_chat"
    CLARIFAI = "clarifai"
    ANTHROPIC = "anthropic"
    ANTHROPIC_TEXT = "anthropic_text"
    REPLICATE = "replicate"
    HUGGINGFACE = "huggingface"
    TOGETHER_AI = "together_ai"
@@ -1060,7 +1061,7 @@ from .llms.anthropic.experimental_pass_through.transformation import (
    AnthropicExperimentalPassThroughConfig,
)
from .llms.groq.stt.transformation import GroqSTTConfig
from .llms.anthropic.completion import AnthropicTextConfig
from .llms.anthropic.completion.transformation import AnthropicTextConfig
from .llms.databricks.chat.transformation import DatabricksConfig
from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
from .llms.predibase import PredibaseConfig
@@ -12,6 +12,7 @@ LITELLM_CHAT_PROVIDERS = [
    "cohere_chat",
    "clarifai",
    "anthropic",
    "anthropic_text",
    "replicate",
    "huggingface",
    "together_ai",
@@ -52,6 +52,39 @@ def handle_cohere_chat_model_custom_llm_provider(
    return model, custom_llm_provider


def handle_anthropic_text_model_custom_llm_provider(
    model: str, custom_llm_provider: Optional[str] = None
) -> Tuple[str, Optional[str]]:
    """
    if user sets model = "anthropic/claude-2" -> use custom_llm_provider = "anthropic_text"

    Args:
        model:
        custom_llm_provider:

    Returns:
        model, custom_llm_provider
    """

    if custom_llm_provider:
        if (
            custom_llm_provider == "anthropic"
            and litellm.AnthropicTextConfig._is_anthropic_text_model(model)
        ):
            return model, "anthropic_text"

    if "/" in model:
        _custom_llm_provider, _model = model.split("/", 1)
        if (
            _custom_llm_provider
            and _custom_llm_provider == "anthropic"
            and litellm.AnthropicTextConfig._is_anthropic_text_model(_model)
        ):
            return _model, "anthropic_text"

    return model, custom_llm_provider


def get_llm_provider(  # noqa: PLR0915
    model: str,
    custom_llm_provider: Optional[str] = None,
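A quick sketch of the routing this helper produces (illustrative only; in litellm, `get_llm_provider` returns a `(model, custom_llm_provider, dynamic_api_key, api_base)` tuple, and the expected values in the comments follow from the function above):

import litellm

model, provider, _, _ = litellm.get_llm_provider(model="anthropic/claude-2")
# legacy text model -> model == "claude-2", provider == "anthropic_text"

model, provider, _, _ = litellm.get_llm_provider(model="anthropic/claude-3-5-sonnet-20240620")
# claude-3 family models keep the default /messages route -> provider == "anthropic"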
@@ -92,6 +125,10 @@ def get_llm_provider(  # noqa: PLR0915
            model, custom_llm_provider
        )

    model, custom_llm_provider = handle_anthropic_text_model_custom_llm_provider(
        model, custom_llm_provider
    )

    if custom_llm_provider:
        if (
            model.split("/")[0] == custom_llm_provider
@@ -210,7 +247,10 @@ def get_llm_provider(  # noqa: PLR0915
            custom_llm_provider = "text-completion-openai"
        ## anthropic
        elif model in litellm.anthropic_models:
            custom_llm_provider = "anthropic"
            if litellm.AnthropicTextConfig._is_anthropic_text_model(model):
                custom_llm_provider = "anthropic_text"
            else:
                custom_llm_provider = "anthropic"
        ## cohere
        elif model in litellm.cohere_models or model in litellm.cohere_embedding_models:
            custom_llm_provider = "cohere"
@@ -531,7 +571,9 @@ def _get_openai_compatible_provider_info(  # noqa: PLR0915
        )
    elif custom_llm_provider == "galadriel":
        api_base = (
            api_base or get_secret("GALADRIEL_API_BASE") or "https://api.galadriel.com/v1"
            api_base
            or get_secret("GALADRIEL_API_BASE")
            or "https://api.galadriel.com/v1"
        )  # type: ignore
        dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
        if api_base is not None and not isinstance(api_base, str):
@@ -223,40 +223,6 @@ class CustomStreamWrapper:
            self.holding_chunk = ""
            return hold, curr_chunk

    def handle_anthropic_text_chunk(self, chunk):
        """
        For old anthropic models - claude-1, claude-2.

        Claude-3 is handled from within Anthropic.py VIA ModelResponseIterator()
        """
        str_line = chunk
        if isinstance(chunk, bytes):  # Handle binary data
            str_line = chunk.decode("utf-8")  # Convert bytes to string
        text = ""
        is_finished = False
        finish_reason = None
        if str_line.startswith("data:"):
            data_json = json.loads(str_line[5:])
            type_chunk = data_json.get("type", None)
            if type_chunk == "completion":
                text = data_json.get("completion")
                finish_reason = data_json.get("stop_reason")
                if finish_reason is not None:
                    is_finished = True
            return {
                "text": text,
                "is_finished": is_finished,
                "finish_reason": finish_reason,
            }
        elif "error" in str_line:
            raise ValueError(f"Unable to parse response. Original response: {str_line}")
        else:
            return {
                "text": text,
                "is_finished": is_finished,
                "finish_reason": finish_reason,
            }

    def handle_predibase_chunk(self, chunk):
        try:
            if not isinstance(chunk, str):
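For reference, the server-sent event lines this removed helper parsed look like the following (shape taken from the parsing logic above; values are illustrative):

# raw /v1/complete stream line for the old claude-1 / claude-2 models
chunk = 'data: {"type": "completion", "completion": " Hello", "stop_reason": null}'

# the removed handle_anthropic_text_chunk() turned it into:
#   {"text": " Hello", "is_finished": False, "finish_reason": None}
# and once stop_reason is set (e.g. "stop_sequence"), is_finished becomes True.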
@@ -1005,14 +971,6 @@ class CustomStreamWrapper:
                        setattr(model_response, key, value)

                response_obj = anthropic_response_obj
            elif (
                self.custom_llm_provider
                and self.custom_llm_provider == "anthropic_text"
            ):
                response_obj = self.handle_anthropic_text_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.model == "replicate" or self.custom_llm_provider == "replicate":
                response_obj = self.handle_replicate_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
@@ -1,362 +0,0 @@
"""
Translation logic for anthropic's `/v1/complete` endpoint
"""

import json
import os
import time
import types
from enum import Enum
from typing import Callable, Optional

import httpx
import requests

import litellm
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    get_async_httpx_client,
)
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage

from ..base import BaseLLM
from ..prompt_templates.factory import custom_prompt, prompt_factory


class AnthropicConstants(Enum):
    HUMAN_PROMPT = "\n\nHuman: "
    AI_PROMPT = "\n\nAssistant: "


class AnthropicError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://api.anthropic.com/v1/complete"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class AnthropicTextConfig:
    """
    Reference: https://docs.anthropic.com/claude/reference/complete_post

    to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
    """

    max_tokens_to_sample: Optional[int] = (
        litellm.max_tokens
    )  # anthropic requires a default
    stop_sequences: Optional[list] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    top_k: Optional[int] = None
    metadata: Optional[dict] = None

    def __init__(
        self,
        max_tokens_to_sample: Optional[int] = 256,  # anthropic requires a default
        stop_sequences: Optional[list] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        top_k: Optional[int] = None,
        metadata: Optional[dict] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }


# makes headers for API call
def validate_environment(api_key, user_headers):
    if api_key is None:
        raise ValueError(
            "Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params"
        )
    headers = {
        "accept": "application/json",
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
        "x-api-key": api_key,
    }
    if user_headers is not None and isinstance(user_headers, dict):
        headers = {**headers, **user_headers}
    return headers


class AnthropicTextCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    def _process_response(
        self, model_response: ModelResponse, response, encoding, prompt: str, model: str
    ):
        ## RESPONSE OBJECT
        try:
            completion_response = response.json()
        except Exception:
            raise AnthropicError(
                message=response.text, status_code=response.status_code
            )
        if "error" in completion_response:
            raise AnthropicError(
                message=str(completion_response["error"]),
                status_code=response.status_code,
            )
        else:
            if len(completion_response["completion"]) > 0:
                model_response.choices[0].message.content = completion_response[  # type: ignore
                    "completion"
                ]
            model_response.choices[0].finish_reason = completion_response["stop_reason"]

        ## CALCULATING USAGE
        prompt_tokens = len(
            encoding.encode(prompt)
        )  ##[TODO] use the anthropic tokenizer here
        completion_tokens = len(
            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
        )  ##[TODO] use the anthropic tokenizer here

        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )

        setattr(model_response, "usage", usage)

        return model_response

    async def async_completion(
        self,
        model: str,
        model_response: ModelResponse,
        api_base: str,
        logging_obj,
        encoding,
        headers: dict,
        data: dict,
        client=None,
    ):
        if client is None:
            client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.ANTHROPIC,
                params={"timeout": httpx.Timeout(timeout=600.0, connect=5.0)},
            )

        response = await client.post(api_base, headers=headers, data=json.dumps(data))

        if response.status_code != 200:
            raise AnthropicError(
                status_code=response.status_code, message=response.text
            )

        ## LOGGING
        logging_obj.post_call(
            input=data["prompt"],
            api_key=headers.get("x-api-key"),
            original_response=response.text,
            additional_args={"complete_input_dict": data},
        )

        response = self._process_response(
            model_response=model_response,
            response=response,
            encoding=encoding,
            prompt=data["prompt"],
            model=model,
        )
        return response

    async def async_streaming(
        self,
        model: str,
        api_base: str,
        logging_obj,
        headers: dict,
        data: Optional[dict],
        client=None,
    ):
        if client is None:
            client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.ANTHROPIC,
                params={"timeout": httpx.Timeout(timeout=600.0, connect=5.0)},
            )

        response = await client.post(api_base, headers=headers, data=json.dumps(data))

        if response.status_code != 200:
            raise AnthropicError(
                status_code=response.status_code, message=response.text
            )

        completion_stream = response.aiter_lines()

        streamwrapper = CustomStreamWrapper(
            completion_stream=completion_stream,
            model=model,
            custom_llm_provider="anthropic_text",
            logging_obj=logging_obj,
        )
        return streamwrapper

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        acompletion: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        litellm_params=None,
        logger_fn=None,
        headers={},
        client=None,
    ):
        headers = validate_environment(api_key, headers)
        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details["roles"],
                initial_prompt_value=model_prompt_details["initial_prompt_value"],
                final_prompt_value=model_prompt_details["final_prompt_value"],
                messages=messages,
            )
        else:
            prompt = prompt_factory(
                model=model, messages=messages, custom_llm_provider="anthropic"
            )

        ## Load Config
        config = litellm.AnthropicTextConfig.get_config()
        for k, v in config.items():
            if (
                k not in optional_params
            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                optional_params[k] = v

        data = {
            "model": model,
            "prompt": prompt,
            **optional_params,
        }

        ## LOGGING
        logging_obj.pre_call(
            input=prompt,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "api_base": api_base,
                "headers": headers,
            },
        )

        ## COMPLETION CALL
        if "stream" in optional_params and optional_params["stream"] is True:
            if acompletion is True:
                return self.async_streaming(
                    model=model,
                    api_base=api_base,
                    logging_obj=logging_obj,
                    headers=headers,
                    data=data,
                    client=None,
                )

            if client is None:
                client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))

            response = client.post(
                api_base,
                headers=headers,
                data=json.dumps(data),
                # stream=optional_params["stream"],
            )

            if response.status_code != 200:
                raise AnthropicError(
                    status_code=response.status_code, message=response.text
                )
            completion_stream = response.iter_lines()
            stream_response = CustomStreamWrapper(
                completion_stream=completion_stream,
                model=model,
                custom_llm_provider="anthropic_text",
                logging_obj=logging_obj,
            )
            return stream_response
        elif acompletion is True:
            return self.async_completion(
                model=model,
                model_response=model_response,
                api_base=api_base,
                logging_obj=logging_obj,
                encoding=encoding,
                headers=headers,
                data=data,
                client=client,
            )
        else:
            if client is None:
                client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
            response = client.post(api_base, headers=headers, data=json.dumps(data))
            if response.status_code != 200:
                raise AnthropicError(
                    status_code=response.status_code, message=response.text
                )

            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                original_response=response.text,
                additional_args={"complete_input_dict": data},
            )
            print_verbose(f"raw model_response: {response.text}")

            response = self._process_response(
                model_response=model_response,
                response=response,
                encoding=encoding,
                prompt=data["prompt"],
                model=model,
            )
            return response

    def embedding(self):
        # logic for parsing in - calling - parsing out model embedding calls
        pass
litellm/llms/anthropic/completion/handler.py (new file, 5 lines)
@@ -0,0 +1,5 @@
"""
Anthropic /complete API - uses `llm_http_handler.py` to make httpx requests

Request/Response transformation is handled in `transformation.py`
"""
litellm/llms/anthropic/completion/transformation.py (new file, 307 lines)
@@ -0,0 +1,307 @@
"""
Translation logic for anthropic's `/v1/complete` endpoint

Litellm provider slug: `anthropic_text/<model_name>`
"""

import json
import time
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

import httpx

import litellm
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.transformation import (
    BaseConfig,
    BaseLLMException,
    LiteLLMLoggingObj,
)
from litellm.llms.prompt_templates.factory import custom_prompt, prompt_factory
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
    ChatCompletionToolCallChunk,
    ChatCompletionUsageBlock,
    GenericStreamingChunk,
    ModelResponse,
    Usage,
)


class AnthropicTextError(BaseLLMException):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://api.anthropic.com/v1/complete"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            message=self.message,
            status_code=self.status_code,
            request=self.request,
            response=self.response,
        )  # Call the base class constructor with the parameters it needs


class AnthropicTextConfig(BaseConfig):
    """
    Reference: https://docs.anthropic.com/claude/reference/complete_post

    to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
    """

    max_tokens_to_sample: Optional[int] = (
        litellm.max_tokens
    )  # anthropic requires a default
    stop_sequences: Optional[list] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    top_k: Optional[int] = None
    metadata: Optional[dict] = None

    def __init__(
        self,
        max_tokens_to_sample: Optional[int] = 256,  # anthropic requires a default
        stop_sequences: Optional[list] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        top_k: Optional[int] = None,
        metadata: Optional[dict] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    # makes headers for API call
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
    ) -> dict:
        if api_key is None:
            raise ValueError(
                "Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params"
            )
        _headers = {
            "accept": "application/json",
            "anthropic-version": "2023-06-01",
            "content-type": "application/json",
            "x-api-key": api_key,
        }
        headers.update(_headers)
        return headers
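The headers produced by validate_environment above, illustrated (a sketch; the api_key value is a placeholder):

import litellm

headers = litellm.AnthropicTextConfig().validate_environment(
    headers={},
    model="claude-2",
    messages=[{"role": "user", "content": "hi"}],
    optional_params={},
    api_key="sk-ant-...",  # placeholder key
)
# -> {"accept": "application/json", "anthropic-version": "2023-06-01",
#     "content-type": "application/json", "x-api-key": "sk-ant-..."}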
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        prompt = self._get_anthropic_text_prompt_from_messages(
            messages=messages, model=model
        )
        ## Load Config
        config = litellm.AnthropicTextConfig.get_config()
        for k, v in config.items():
            if (
                k not in optional_params
            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                optional_params[k] = v

        data = {
            "model": model,
            "prompt": prompt,
            **optional_params,
        }

        return data
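What transform_request produces for a simple chat-style input (a sketch; the exact prompt string depends on prompt_factory's Human/Assistant formatting for anthropic, so treat it as approximate):

import litellm

data = litellm.AnthropicTextConfig().transform_request(
    model="claude-2",
    messages=[{"role": "user", "content": "Generate a 3 liner joke for me"}],
    optional_params={"max_tokens_to_sample": 256, "stream": False},
    litellm_params={},
    headers={},
)
# roughly:
# {
#     "model": "claude-2",
#     "prompt": "\n\nHuman: Generate a 3 liner joke for me\n\nAssistant: ",
#     "max_tokens_to_sample": 256,
#     "stream": False,
# }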
    def get_supported_openai_params(self, model: str):
        """
        Anthropic /complete API Ref: https://docs.anthropic.com/en/api/complete
        """
        return [
            "stream",
            "max_tokens",
            "max_completion_tokens",
            "stop",
            "temperature",
            "top_p",
            "extra_headers",
            "user",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Follows the same logic as the AnthropicConfig.map_openai_params method (which is the Anthropic /messages API)

        Note: the only difference is in the get supported openai params method between the AnthropicConfig and AnthropicTextConfig
        API Ref: https://docs.anthropic.com/en/api/complete
        """
        for param, value in non_default_params.items():
            if param == "max_tokens":
                optional_params["max_tokens_to_sample"] = value
            if param == "max_completion_tokens":
                optional_params["max_tokens_to_sample"] = value
            if param == "stream" and value is True:
                optional_params["stream"] = value
            if param == "stop" and (isinstance(value, str) or isinstance(value, list)):
                _value = litellm.AnthropicConfig()._map_stop_sequences(value)
                if _value is not None:
                    optional_params["stop_sequences"] = _value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "user":
                optional_params["metadata"] = {"user_id": value}

        return optional_params
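The OpenAI-to-Anthropic parameter mapping above, illustrated (a sketch; `stop` values go through AnthropicConfig._map_stop_sequences, so their exact shape follows that helper):

import litellm

optional_params = litellm.AnthropicTextConfig().map_openai_params(
    non_default_params={"max_tokens": 100, "temperature": 0.2, "user": "user-123"},
    optional_params={},
    model="claude-2",
    drop_params=False,
)
# -> {"max_tokens_to_sample": 100, "temperature": 0.2,
#     "metadata": {"user_id": "user-123"}}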
    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        encoding: str,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        try:
            completion_response = raw_response.json()
        except Exception:
            raise AnthropicTextError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        prompt = self._get_anthropic_text_prompt_from_messages(
            messages=messages, model=model
        )
        if "error" in completion_response:
            raise AnthropicTextError(
                message=str(completion_response["error"]),
                status_code=raw_response.status_code,
            )
        else:
            if len(completion_response["completion"]) > 0:
                model_response.choices[0].message.content = completion_response[  # type: ignore
                    "completion"
                ]
            model_response.choices[0].finish_reason = completion_response["stop_reason"]

        ## CALCULATING USAGE
        prompt_tokens = len(
            encoding.encode(prompt)
        )  ##[TODO] use the anthropic tokenizer here
        completion_tokens = len(
            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
        )  ##[TODO] use the anthropic tokenizer here

        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )

        setattr(model_response, "usage", usage)
        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
    ) -> BaseLLMException:
        return AnthropicTextError(
            status_code=status_code,
            message=error_message,
        )

    @staticmethod
    def _is_anthropic_text_model(model: str) -> bool:
        return model == "claude-2" or model == "claude-instant-1"

    def _get_anthropic_text_prompt_from_messages(
        self, messages: List[AllMessageValues], model: str
    ) -> str:
        custom_prompt_dict = litellm.custom_prompt_dict
        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details["roles"],
                initial_prompt_value=model_prompt_details["initial_prompt_value"],
                final_prompt_value=model_prompt_details["final_prompt_value"],
                messages=messages,
            )
        else:
            prompt = prompt_factory(
                model=model, messages=messages, custom_llm_provider="anthropic"
            )

        return str(prompt)

    def _transform_messages(
        self, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
        "Not required"
        raise NotImplementedError

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        return AnthropicTextCompletionResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )


class AnthropicTextCompletionResponseIterator(BaseModelResponseIterator):
    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        try:
            text = ""
            tool_use: Optional[ChatCompletionToolCallChunk] = None
            is_finished = False
            finish_reason = ""
            usage: Optional[ChatCompletionUsageBlock] = None
            provider_specific_fields = None
            index = int(chunk.get("index", 0))
            _chunk_text = chunk.get("completion", None)
            if _chunk_text is not None and isinstance(_chunk_text, str):
                text = _chunk_text
            finish_reason = chunk.get("stop_reason", None)
            if finish_reason is not None:
                is_finished = True
            returned_chunk = GenericStreamingChunk(
                text=text,
                tool_use=tool_use,
                is_finished=is_finished,
                finish_reason=finish_reason,
                usage=usage,
                index=index,
                provider_specific_fields=provider_specific_fields,
            )

            return returned_chunk

        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
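A small sketch of what the new iterator's chunk_parser returns for one decoded /v1/complete event (values illustrative, behavior read off the code above):

# one decoded stream event, i.e. the dict handed to chunk_parser()
chunk = {"type": "completion", "completion": " Hello", "stop_reason": None, "index": 0}
# chunk_parser maps it to a GenericStreamingChunk with:
#   text=" Hello", is_finished=False, finish_reason=None, usage=None, index=0
# once stop_reason arrives (e.g. "stop_sequence"), is_finished flips to True.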
@@ -2839,7 +2839,7 @@ def prompt_factory(
    if custom_llm_provider == "ollama":
        return ollama_pt(model=model, messages=messages)
    elif custom_llm_provider == "anthropic":
        if model == "claude-instant-1" or model == "claude-2":
        if litellm.AnthropicTextConfig._is_anthropic_text_model(model):
            return anthropic_pt(messages=messages)
        return anthropic_messages_pt(
            messages=messages, model=model, llm_provider=custom_llm_provider
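The legacy-model branch above falls back to anthropic_pt(), which renders chat messages into the Human/Assistant string expected by /v1/complete (a rough sketch; the separators follow the AnthropicConstants shown elsewhere in this diff):

# anthropic_pt(messages=[{"role": "user", "content": "Hi"}]) yields roughly:
#   "\n\nHuman: Hi\n\nAssistant: "
# claude-3 models instead go through anthropic_messages_pt for the /messages API.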
litellm/main.py (129 lines changed)
@@ -99,7 +99,6 @@ from .llms import (
)
from .llms.ai21 import completion as ai21
from .llms.anthropic.chat import AnthropicChatCompletion
from .llms.anthropic.completion import AnthropicTextCompletion
from .llms.azure.audio_transcriptions import AzureAudioTranscription
from .llms.azure.azure import AzureChatCompletion, _check_dynamic_azure_params
from .llms.azure.chat.o1_handler import AzureOpenAIO1ChatCompletion
@@ -204,7 +203,6 @@ together_ai_text_completions = TogetherAITextCompletion()
azure_ai_chat_completions = AzureAIChatCompletion()
azure_ai_embedding = AzureAIEmbedding()
anthropic_chat_completions = AnthropicChatCompletion()
anthropic_text_completions = AnthropicTextCompletion()
azure_chat_completions = AzureChatCompletion()
azure_o1_chat_completions = AzureOpenAIO1ChatCompletion()
azure_text_completions = AzureTextCompletion()
@@ -464,6 +462,7 @@ async def acompletion(
        or custom_llm_provider == "sagemaker"
        or custom_llm_provider == "sagemaker_chat"
        or custom_llm_provider == "anthropic"
        or custom_llm_provider == "anthropic_text"
        or custom_llm_provider == "predibase"
        or custom_llm_provider == "bedrock"
        or custom_llm_provider == "databricks"
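The async path now recognizes anthropic_text as well; a minimal sketch mirroring the streaming test added in this commit (assumes ANTHROPIC_API_KEY is set):

import asyncio

import litellm


async def main():
    response = await litellm.acompletion(
        model="anthropic_text/claude-2",
        messages=[{"role": "user", "content": "Generate a 3 liner joke for me"}],
        max_tokens=10,
        stream=True,
    )
    # chunks are produced by AnthropicTextCompletionResponseIterator
    async for chunk in response:
        print(chunk)


asyncio.run(main())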
@@ -1705,6 +1704,41 @@ def completion(  # type: ignore # noqa: PLR0915
                api_key=clarifai_key,
                logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
            )
    elif custom_llm_provider == "anthropic_text":
        api_key = (
            api_key
            or litellm.anthropic_key
            or litellm.api_key
            or os.environ.get("ANTHROPIC_API_KEY")
        )
        custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
        api_base = (
            api_base
            or litellm.api_base
            or get_secret("ANTHROPIC_API_BASE")
            or get_secret("ANTHROPIC_BASE_URL")
            or "https://api.anthropic.com/v1/complete"
        )

        if api_base is not None and not api_base.endswith("/v1/complete"):
            api_base += "/v1/complete"

        response = base_llm_http_handler.completion(
            model=model,
            stream=stream,
            messages=messages,
            acompletion=acompletion,
            api_base=api_base,
            model_response=model_response,
            optional_params=optional_params,
            litellm_params=litellm_params,
            custom_llm_provider="anthropic_text",
            timeout=timeout,
            headers=headers,
            encoding=encoding,
            api_key=api_key,
            logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
        )
    elif custom_llm_provider == "anthropic":
        api_key = (
            api_key
@@ -1713,69 +1747,38 @@ def completion(  # type: ignore # noqa: PLR0915
            or os.environ.get("ANTHROPIC_API_KEY")
        )
        custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
        # call /messages
        # default route for all anthropic models
        api_base = (
            api_base
            or litellm.api_base
            or get_secret("ANTHROPIC_API_BASE")
            or get_secret("ANTHROPIC_BASE_URL")
            or "https://api.anthropic.com/v1/messages"
        )

        if (model == "claude-2") or (model == "claude-instant-1"):
            # call anthropic /completion, only use this route for claude-2, claude-instant-1
            api_base = (
                api_base
                or litellm.api_base
                or get_secret("ANTHROPIC_API_BASE")
                or get_secret("ANTHROPIC_BASE_URL")
                or "https://api.anthropic.com/v1/complete"
            )
        if api_base is not None and not api_base.endswith("/v1/messages"):
            api_base += "/v1/messages"

            if api_base is not None and not api_base.endswith("/v1/complete"):
                api_base += "/v1/complete"

            response = anthropic_text_completions.completion(
                model=model,
                messages=messages,
                api_base=api_base,
                acompletion=acompletion,
                custom_prompt_dict=litellm.custom_prompt_dict,
                model_response=model_response,
                print_verbose=print_verbose,
                optional_params=optional_params,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                encoding=encoding,  # for calculating input/output tokens
                api_key=api_key,
                logging_obj=logging,
                headers=headers,
            )
        else:
            # call /messages
            # default route for all anthropic models
            api_base = (
                api_base
                or litellm.api_base
                or get_secret("ANTHROPIC_API_BASE")
                or get_secret("ANTHROPIC_BASE_URL")
                or "https://api.anthropic.com/v1/messages"
            )

            if api_base is not None and not api_base.endswith("/v1/messages"):
                api_base += "/v1/messages"

            response = anthropic_chat_completions.completion(
                model=model,
                messages=messages,
                api_base=api_base,
                acompletion=acompletion,
                custom_prompt_dict=litellm.custom_prompt_dict,
                model_response=model_response,
                print_verbose=print_verbose,
                optional_params=optional_params,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                encoding=encoding,  # for calculating input/output tokens
                api_key=api_key,
                logging_obj=logging,
                headers=headers,
                timeout=timeout,
                client=client,
                custom_llm_provider=custom_llm_provider,
            )
        response = anthropic_chat_completions.completion(
            model=model,
            messages=messages,
            api_base=api_base,
            acompletion=acompletion,
            custom_prompt_dict=litellm.custom_prompt_dict,
            model_response=model_response,
            print_verbose=print_verbose,
            optional_params=optional_params,
            litellm_params=litellm_params,
            logger_fn=logger_fn,
            encoding=encoding,  # for calculating input/output tokens
            api_key=api_key,
            logging_obj=logging,
            headers=headers,
            timeout=timeout,
            client=client,
            custom_llm_provider=custom_llm_provider,
        )
        if optional_params.get("stream", False) or acompletion is True:
            ## LOGGING
            logging.post_call(
@@ -2830,6 +2830,32 @@ def get_optional_params(  # noqa: PLR0915
                else False
            ),
        )
    elif custom_llm_provider == "anthropic_text":
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
        )
        optional_params = litellm.AnthropicTextConfig().map_openai_params(
            model=model,
            non_default_params=non_default_params,
            optional_params=optional_params,
            drop_params=(
                drop_params
                if drop_params is not None and isinstance(drop_params, bool)
                else False
            ),
        )
        _check_valid_arg(supported_params=supported_params)
        optional_params = litellm.AnthropicTextConfig().map_openai_params(
            model=model,
            non_default_params=non_default_params,
            optional_params=optional_params,
            drop_params=(
                drop_params
                if drop_params is not None and isinstance(drop_params, bool)
                else False
            ),
        )

    elif custom_llm_provider == "cohere":
        ## check if unsupported param passed in
        supported_params = get_supported_openai_params(
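A hedged sketch of the param handling the branch above wires up; it assumes the public `litellm.get_supported_openai_params` helper dispatches anthropic_text to AnthropicTextConfig, as the code above implies:

import litellm

print(
    litellm.get_supported_openai_params(
        model="claude-2", custom_llm_provider="anthropic_text"
    )
)
# per AnthropicTextConfig.get_supported_openai_params:
# ["stream", "max_tokens", "max_completion_tokens", "stop", "temperature",
#  "top_p", "extra_headers", "user"]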
@@ -4208,7 +4234,7 @@ def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
    if llm_provider == "openai" or llm_provider == "text-completion-openai":
        api_key = api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
    # anthropic
    elif llm_provider == "anthropic":
    elif llm_provider == "anthropic" or llm_provider == "anthropic_text":
        api_key = api_key or litellm.anthropic_key or get_secret("ANTHROPIC_API_KEY")
    # ai21
    elif llm_provider == "ai21":
@@ -6251,6 +6277,8 @@ class ProviderConfigManager:
            return litellm.ClarifaiConfig()
        elif litellm.LlmProviders.ANTHROPIC == provider:
            return litellm.AnthropicConfig()
        elif litellm.LlmProviders.ANTHROPIC_TEXT == provider:
            return litellm.AnthropicTextConfig()
        elif litellm.LlmProviders.VERTEX_AI == provider:
            if "claude" in model:
                return litellm.VertexAIAnthropicConfig()
tests/llm_translation/test_anthropic_text_completion.py (new file, 73 lines)
@@ -0,0 +1,73 @@
import asyncio
import os
from re import T
import sys
import traceback

from dotenv import load_dotenv

import litellm.types
import litellm.types.utils
from litellm.llms.anthropic.chat import ModelResponseIterator

load_dotenv()
import io
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
from typing import Optional
from unittest.mock import MagicMock, patch

import pytest


@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["claude-2", "anthropic/claude-2"])
async def test_acompletion_claude2(model):
    try:
        litellm.set_verbose = True
        messages = [
            {
                "role": "system",
                "content": "Your goal is generate a joke on the topic user gives.",
            },
            {"role": "user", "content": "Generate a 3 liner joke for me"},
        ]
        # test without max-tokens
        response = await litellm.acompletion(model=model, messages=messages)
        # Add any assertions here to check the response
        print(response)
        print(response.usage)
        print(response.usage.completion_tokens)
        print(response["usage"]["completion_tokens"])
        # print("new cost tracking")
    except litellm.InternalServerError:
        pytest.skip("model is overloaded.")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.asyncio
async def test_acompletion_claude2_stream():
    try:
        litellm.set_verbose = False
        messages = [
            {
                "role": "system",
                "content": "Your goal is generate a joke on the topic user gives.",
            },
            {"role": "user", "content": "Generate a 3 liner joke for me"},
        ]
        # test without max-tokens
        response = await litellm.acompletion(
            model="anthropic_text/claude-2",
            messages=messages,
            stream=True,
            max_tokens=10,
        )
        async for chunk in response:
            print(chunk)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
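A synchronous counterpart to the streaming test above (sketch; same model and provider prefix, without the pytest wrapper):

import litellm

response = litellm.completion(
    model="anthropic_text/claude-2",
    messages=[{"role": "user", "content": "Generate a 3 liner joke for me"}],
    stream=True,
    max_tokens=10,
)
# sync streaming goes through CustomStreamWrapper
for chunk in response:
    print(chunk)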
@@ -1126,32 +1126,6 @@ def test_completion_mistral_api_modified_input():
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.asyncio
async def test_acompletion_claude2_1():
    try:
        litellm.set_verbose = True
        print("claude2.1 test request")
        messages = [
            {
                "role": "system",
                "content": "Your goal is generate a joke on the topic user gives.",
            },
            {"role": "user", "content": "Generate a 3 liner joke for me"},
        ]
        # test without max-tokens
        response = await litellm.acompletion(model="claude-2.1", messages=messages)
        # Add any assertions here to check the response
        print(response)
        print(response.usage)
        print(response.usage.completion_tokens)
        print(response["usage"]["completion_tokens"])
        # print("new cost tracking")
    except litellm.InternalServerError:
        pytest.skip("model is overloaded.")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# def test_completion_oobabooga():
#     try:
#         response = completion(