mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
Merge pull request #144 from yujonglee/fix-modelresponse-typing
Fix ModelResponse typing
This commit is contained in:
commit
3d1934a829
4 changed files with 8 additions and 11 deletions
|
@ -4,6 +4,7 @@ import requests
|
|||
from litellm import logging
|
||||
import time
|
||||
from typing import Callable
|
||||
from litellm.utils import ModelResponse
|
||||
|
||||
class AnthropicConstants(Enum):
|
||||
HUMAN_PROMPT = "\n\nHuman:"
|
||||
|
@ -36,7 +37,7 @@ class AnthropicLLM:
|
|||
"x-api-key": self.api_key
|
||||
}
|
||||
|
||||
def completion(self, model: str, messages: list, model_response: dict, print_verbose: Callable, optional_params=None, litellm_params=None, logger_fn=None): # logic for parsing in - calling - parsing out model completion calls
|
||||
def completion(self, model: str, messages: list, model_response: ModelResponse, print_verbose: Callable, optional_params=None, litellm_params=None, logger_fn=None): # logic for parsing in - calling - parsing out model completion calls
|
||||
model = model
|
||||
prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}"
|
||||
for message in messages:
|
||||
|
|
|
@ -5,6 +5,7 @@ import requests
|
|||
from litellm import logging
|
||||
import time
|
||||
from typing import Callable
|
||||
from litellm.utils import ModelResponse
|
||||
|
||||
class HuggingfaceError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
|
@ -26,7 +27,7 @@ class HuggingfaceRestAPILLM():
|
|||
if self.api_key != None:
|
||||
self.headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
def completion(self, model: str, messages: list, custom_api_base: str, model_response: dict, print_verbose: Callable, optional_params=None, litellm_params=None, logger_fn=None): # logic for parsing in - calling - parsing out model completion calls
|
||||
def completion(self, model: str, messages: list, custom_api_base: str, model_response: ModelResponse, print_verbose: Callable, optional_params=None, litellm_params=None, logger_fn=None): # logic for parsing in - calling - parsing out model completion calls
|
||||
if custom_api_base:
|
||||
completion_url = custom_api_base
|
||||
elif "HF_API_BASE" in os.environ:
|
||||
|
|
|
@ -40,7 +40,7 @@ def completion(
|
|||
# model specific optional params
|
||||
# used by text-bison only
|
||||
top_k=40, request_timeout=0, # unused var for old version of OpenAI API
|
||||
):
|
||||
) -> ModelResponse:
|
||||
try:
|
||||
model_response = ModelResponse()
|
||||
if azure: # this flag is deprecated, remove once notebooks are also updated.
|
||||
|
|
|
@ -94,14 +94,9 @@ def test_completion_openai():
|
|||
|
||||
response_str = response['choices'][0]['message']['content']
|
||||
response_str_2 = response.choices[0].message.content
|
||||
print(response_str)
|
||||
print(response_str_2)
|
||||
if type(response_str) != str:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
if type(response_str_2) != str:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
assert response_str == response_str_2
|
||||
assert type(response_str) == str
|
||||
assert len(response_str) > 1
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue