forked from phoenix/litellm-mirror
(feat) - add claude 3
parent b283588eb8
commit 19eb9063fb
4 changed files with 32 additions and 19 deletions
@@ -20,7 +20,7 @@ class AnthropicError(Exception):
         self.status_code = status_code
         self.message = message
         self.request = httpx.Request(
-            method="POST", url="https://api.anthropic.com/v1/complete"
+            method="POST", url="https://api.anthropic.com/v1/messages"
         )
         self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
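
A note on the endpoint change: Claude 3 models are served only by Anthropic's newer Messages API, so the placeholder request attached to AnthropicError moves from the legacy /v1/complete endpoint to /v1/messages. For orientation, a minimal sketch of a raw call against the new endpoint (not part of the commit; assumes ANTHROPIC_API_KEY is exported):

import os
import httpx

# Direct Messages API call; headers and body shape per Anthropic's API.
response = httpx.post(
    "https://api.anthropic.com/v1/messages",
    headers={
        "x-api-key": os.environ["ANTHROPIC_API_KEY"],
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    },
    json={
        "model": "claude-3-opus-20240229",
        "max_tokens": 10,
        "messages": [{"role": "user", "content": "Hello, world"}],
    },
)
print(response.json())
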
@@ -35,9 +35,7 @@ class AnthropicConfig:
     to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
     """

-    max_tokens_to_sample: Optional[
-        int
-    ] = litellm.max_tokens  # anthropic requires a default
+    max_tokens: Optional[int] = litellm.max_tokens  # anthropic requires a default
     stop_sequences: Optional[list] = None
     temperature: Optional[int] = None
     top_p: Optional[int] = None
@@ -46,7 +44,7 @@ class AnthropicConfig:
     def __init__(
         self,
-        max_tokens_to_sample: Optional[int] = 256,  # anthropic requires a default
+        max_tokens: Optional[int] = 256,  # anthropic requires a default
         stop_sequences: Optional[list] = None,
         temperature: Optional[int] = None,
         top_p: Optional[int] = None,
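
Both hunks above perform the same rename: the Text Completions parameter max_tokens_to_sample becomes the Messages API's max_tokens, in the class-level field and in the constructor default. A rough sketch of the resulting config surface (values illustrative; relies on litellm's pattern of mirroring constructor arguments into get_config(), which is an assumption here):

import litellm

# After the rename the provider config carries `max_tokens`;
# `max_tokens_to_sample` no longer exists. Values are examples only.
litellm.AnthropicConfig(max_tokens=256)
print(litellm.AnthropicConfig.get_config())  # expected to include {"max_tokens": 256}
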
@@ -124,6 +122,10 @@ def completion(
         model=model, messages=messages, custom_llm_provider="anthropic"
     )

+    for message in messages:
+        if message["role"] == "system":
+            message["role"] = "assistant"
+
     ## Load Config
     config = litellm.AnthropicConfig.get_config()
     for k, v in config.items():
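
Context for this hunk: the Messages API accepts only "user" and "assistant" roles inside messages (system prompts travel in a separate top-level system field), so the commit downgrades any system message to assistant before the payload is built. The same transformation, as a standalone sketch:

# The remap shown in the diff, run on example input.
messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hello, world"},
]
for message in messages:
    if message["role"] == "system":
        message["role"] = "assistant"

assert messages[0]["role"] == "assistant"
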
@@ -134,7 +136,7 @@ def completion(
     data = {
         "model": model,
-        "prompt": prompt,
+        "messages": messages,
         **optional_params,
     }
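
With the legacy "prompt" string replaced by "messages", the request body now carries the chat turns verbatim. An illustrative payload after this change (values are examples, not taken from the commit):

data = {
    "model": "claude-3-opus-20240229",
    "messages": [{"role": "user", "content": "Hello, world"}],
    "max_tokens": 256,  # folded in via **optional_params
}
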
@@ -173,7 +175,7 @@ def completion(
         ## LOGGING
         logging_obj.post_call(
-            input=prompt,
+            input=messages,
             api_key=api_key,
             original_response=response.text,
             additional_args={"complete_input_dict": data},
@@ -192,19 +194,14 @@ def completion(
                 status_code=response.status_code,
             )
         else:
-            if len(completion_response["completion"]) > 0:
-                model_response["choices"][0]["message"][
-                    "content"
-                ] = completion_response["completion"]
+            text_content = completion_response["content"][0].get("text", None)
+            model_response.choices[0].message.content = text_content  # type: ignore
             model_response.choices[0].finish_reason = completion_response["stop_reason"]

         ## CALCULATING USAGE
-        prompt_tokens = len(
-            encoding.encode(prompt)
-        )  ##[TODO] use the anthropic tokenizer here
-        completion_tokens = len(
-            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-        )  ##[TODO] use the anthropic tokenizer here
+        prompt_tokens = completion_response["usage"]["input_tokens"]
+        completion_tokens = completion_response["usage"]["output_tokens"]
         total_tokens = prompt_tokens + completion_tokens

         model_response["created"] = int(time.time())
         model_response["model"] = model
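
Background for this hunk: the Messages API returns a list of content blocks plus exact token counts, so the handler stops reading completion_response["completion"] and drops the local tokenizer estimate (and its TODO) in favor of the usage block. An example response shape (illustrative values) and how the new lines consume it:

# Example Messages API response body, trimmed to the fields used above.
completion_response = {
    "content": [{"type": "text", "text": "Hello!"}],
    "stop_reason": "end_turn",
    "usage": {"input_tokens": 10, "output_tokens": 3},
}

text_content = completion_response["content"][0].get("text", None)
prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"]
total_tokens = prompt_tokens + completion_tokens
print(text_content, total_tokens)  # Hello! 13
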
@@ -1023,7 +1023,7 @@ def completion(
             api_base
             or litellm.api_base
             or get_secret("ANTHROPIC_API_BASE")
-            or "https://api.anthropic.com/v1/complete"
+            or "https://api.anthropic.com/v1/messages"
         )
         custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
         response = anthropic.completion(
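
The default base URL moves with the endpoint, but it remains the last fallback: an explicit api_base, litellm.api_base, or the ANTHROPIC_API_BASE secret all take precedence. For example, to route through a proxy (hypothetical URL, not from the commit):

import os
import litellm

# Hypothetical proxy; the get_secret("ANTHROPIC_API_BASE") lookup above
# picks this up before falling back to the new /v1/messages default.
os.environ["ANTHROPIC_API_BASE"] = "https://my-anthropic-proxy.internal/v1/messages"
response = litellm.completion(
    model="anthropic/claude-3-opus-20240229",
    messages=[{"role": "user", "content": "Hello, world"}],
    max_tokens=10,
)
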
@@ -84,6 +84,22 @@ def test_completion_claude():
 # test_completion_claude()


+def test_completion_claude_3():
+    litellm.set_verbose = True
+    messages = [{"role": "user", "content": "Hello, world"}]
+    try:
+        # test without max tokens
+        response = completion(
+            model="anthropic/claude-3-opus-20240229",
+            messages=messages,
+            max_tokens=10,
+        )
+        # Add any assertions, here to check response args
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_mistral_api():
     try:
         litellm.set_verbose = True
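
To exercise just the new test, something like the following works, assuming ANTHROPIC_API_KEY is exported and it is run from the directory containing the test module (the diff view does not show the file name):

import pytest

# Select only the new test; -s streams the printed response.
pytest.main(["-k", "test_completion_claude_3", "-s"])
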
@@ -4200,7 +4200,7 @@ def get_optional_params(
         if top_p is not None:
             optional_params["top_p"] = top_p
         if max_tokens is not None:
-            optional_params["max_tokens_to_sample"] = max_tokens
+            optional_params["max_tokens"] = max_tokens
         elif custom_llm_provider == "cohere":
             ## check if unsupported param passed in
             supported_params = [
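
Callers keep passing the OpenAI-style max_tokens; the anthropic branch now forwards it unchanged instead of renaming it to max_tokens_to_sample. The mapping in isolation (a sketch with example values):

# Standalone illustration of the branch above.
top_p, max_tokens = 0.9, 10
optional_params = {}
if top_p is not None:
    optional_params["top_p"] = top_p
if max_tokens is not None:
    optional_params["max_tokens"] = max_tokens  # was "max_tokens_to_sample"
print(optional_params)  # {'top_p': 0.9, 'max_tokens': 10}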