(feat) litellm.completion - support ollama timeout

This commit is contained in:
ishaan-jaff 2024-01-09 10:31:01 +05:30
parent 10f76ec36c
commit 77027746ba
3 changed files with 49 additions and 17 deletions

View file

@ -179,10 +179,7 @@ def get_ollama_response(
elif optional_params.get("stream", False) == True:
return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
response = requests.post(
url=f"{url}",
json=data,
)
response = requests.post(url=f"{url}", json=data, timeout=litellm.request_timeout)
if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text)

View file

@ -80,7 +80,7 @@ def test_hanging_request_azure():
)
test_hanging_request_azure()
# test_hanging_request_azure()
def test_hanging_request_openai():
@ -156,3 +156,28 @@ def test_timeout_streaming():
# test_timeout_streaming()
def test_timeout_ollama():
# this Will Raise a timeout
import litellm
litellm.set_verbose = True
try:
litellm.request_timeout = 0.1
litellm.set_verbose = True
response = litellm.completion(
model="ollama/phi",
messages=[{"role": "user", "content": "hello, what llm are u"}],
max_tokens=1,
api_base="https://test-ollama-endpoint.onrender.com",
)
# Add any assertions here to check the response
litellm.request_timeout = None
print(response)
except openai.APITimeoutError as e:
print("got a timeout error! Passed ! ")
pass
# test_timeout_ollama()

View file

@ -558,7 +558,7 @@ class TextChoices(OpenAIObject):
def __setitem__(self, key, value):
# Allow dictionary-style assignment of attributes
setattr(self, key, value)
def json(self, **kwargs):
try:
return self.model_dump() # noqa
@ -737,7 +737,9 @@ class Logging:
f"Invalid call_type {call_type}. Allowed values: {allowed_values}"
)
if messages is not None and isinstance(messages, str):
messages = [{"role": "user", "content": messages}] # convert text completion input to the chat completion format
messages = [
{"role": "user", "content": messages}
] # convert text completion input to the chat completion format
self.model = model
self.messages = messages
self.stream = stream
@ -4002,7 +4004,8 @@ def get_llm_provider(
if (
model.split("/", 1)[0] in litellm.provider_list
and model.split("/", 1)[0] not in litellm.model_list
and len(model.split("/")) > 1 # handle edge case where user passes in `litellm --model mistral` https://github.com/BerriAI/litellm/issues/1351
and len(model.split("/"))
> 1 # handle edge case where user passes in `litellm --model mistral` https://github.com/BerriAI/litellm/issues/1351
):
custom_llm_provider = model.split("/", 1)[0]
model = model.split("/", 1)[1]
@ -4137,15 +4140,15 @@ def get_llm_provider(
raise e
else:
raise litellm.exceptions.BadRequestError( # type: ignore
message=f"GetLLMProvider Exception - {str(e)}\n\noriginal model: {model}",
model=model,
response=httpx.Response(
status_code=400,
content=error_str,
request=httpx.request(method="completion", url="https://github.com/BerriAI/litellm"), # type: ignore
),
llm_provider="",
)
message=f"GetLLMProvider Exception - {str(e)}\n\noriginal model: {model}",
model=model,
response=httpx.Response(
status_code=400,
content=error_str,
request=httpx.request(method="completion", url="https://github.com/BerriAI/litellm"), # type: ignore
),
llm_provider="",
)
def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
@ -6460,6 +6463,13 @@ def exception_type(
model=model,
response=original_exception.response,
)
elif "Read timed out" in error_str:
exception_mapping_worked = True
raise Timeout(
message=f"OllamaException: {original_exception}",
llm_provider="ollama",
model=model,
)
elif custom_llm_provider == "vllm":
if hasattr(original_exception, "status_code"):
if original_exception.status_code == 0: