Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)
(feat) litellm.completion - support ollama timeout

parent 10f76ec36c
commit 77027746ba

3 changed files with 49 additions and 17 deletions
@@ -179,10 +179,7 @@ def get_ollama_response(
     elif optional_params.get("stream", False) == True:
         return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
 
-    response = requests.post(
-        url=f"{url}",
-        json=data,
-    )
+    response = requests.post(url=f"{url}", json=data, timeout=litellm.request_timeout)
     if response.status_code != 200:
         raise OllamaError(status_code=response.status_code, message=response.text)
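With this change the non-streaming Ollama request picks up litellm.request_timeout. A minimal usage sketch, assuming a locally running Ollama server; the timeout value and api_base below are placeholders, not part of this commit:

import litellm

# Global request timeout in seconds; forwarded to requests.post(..., timeout=...) for Ollama calls.
litellm.request_timeout = 10  # placeholder value

response = litellm.completion(
    model="ollama/phi",
    messages=[{"role": "user", "content": "hello, what llm are u"}],
    api_base="http://localhost:11434",  # assumed local Ollama endpoint
)
print(response)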
@@ -80,7 +80,7 @@ def test_hanging_request_azure():
     )
 
 
-test_hanging_request_azure()
+# test_hanging_request_azure()
 
 
 def test_hanging_request_openai():
@@ -156,3 +156,28 @@ def test_timeout_streaming():
 
 
 # test_timeout_streaming()
+
+
+def test_timeout_ollama():
+    # this Will Raise a timeout
+    import litellm
+
+    litellm.set_verbose = True
+    try:
+        litellm.request_timeout = 0.1
+        litellm.set_verbose = True
+        response = litellm.completion(
+            model="ollama/phi",
+            messages=[{"role": "user", "content": "hello, what llm are u"}],
+            max_tokens=1,
+            api_base="https://test-ollama-endpoint.onrender.com",
+        )
+        # Add any assertions here to check the response
+        litellm.request_timeout = None
+        print(response)
+    except openai.APITimeoutError as e:
+        print("got a timeout error! Passed ! ")
+        pass
+
+
+# test_timeout_ollama()
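The new test prints and passes on openai.APITimeoutError rather than asserting it. A stricter variant could use pytest.raises; a sketch under the assumption that the mapped exception stays openai.APITimeoutError (the test name and structure are illustrative only):

import openai
import pytest

import litellm


def test_timeout_ollama_raises():
    # Sketch only: force an aggressive timeout and assert the mapped exception surfaces.
    litellm.request_timeout = 0.1
    try:
        with pytest.raises(openai.APITimeoutError):
            litellm.completion(
                model="ollama/phi",
                messages=[{"role": "user", "content": "hello, what llm are u"}],
                max_tokens=1,
                api_base="https://test-ollama-endpoint.onrender.com",
            )
    finally:
        litellm.request_timeout = None  # reset global state for other tests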
@@ -558,7 +558,7 @@ class TextChoices(OpenAIObject):
     def __setitem__(self, key, value):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)
 
     def json(self, **kwargs):
         try:
             return self.model_dump()  # noqa
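For context, __setitem__ here simply delegates to setattr, so dictionary-style and attribute-style writes land on the same fields. A small sketch, assuming TextChoices is importable from litellm.utils and constructible without arguments:

from litellm.utils import TextChoices

tc = TextChoices()
tc["finish_reason"] = "stop"  # routed through __setitem__ -> setattr
assert tc.finish_reason == "stop"
print(tc.json())  # serializes via model_dump() (pydantic v2) or dict() fallback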
@@ -737,7 +737,9 @@ class Logging:
                 f"Invalid call_type {call_type}. Allowed values: {allowed_values}"
             )
         if messages is not None and isinstance(messages, str):
-            messages = [{"role": "user", "content": messages}]  # convert text completion input to the chat completion format
+            messages = [
+                {"role": "user", "content": messages}
+            ]  # convert text completion input to the chat completion format
         self.model = model
         self.messages = messages
         self.stream = stream
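The rewrapped block above keeps the same behavior: a plain string passed as messages is wrapped into the chat-completion format. A standalone illustration of that conversion (plain Python, not a litellm API call):

messages = "hello, what llm are u"

if messages is not None and isinstance(messages, str):
    messages = [
        {"role": "user", "content": messages}
    ]  # convert text completion input to the chat completion format

print(messages)  # [{'role': 'user', 'content': 'hello, what llm are u'}]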
@@ -4002,7 +4004,8 @@ def get_llm_provider(
         if (
             model.split("/", 1)[0] in litellm.provider_list
             and model.split("/", 1)[0] not in litellm.model_list
-            and len(model.split("/")) > 1  # handle edge case where user passes in `litellm --model mistral` https://github.com/BerriAI/litellm/issues/1351
+            and len(model.split("/"))
+            > 1  # handle edge case where user passes in `litellm --model mistral` https://github.com/BerriAI/litellm/issues/1351
         ):
             custom_llm_provider = model.split("/", 1)[0]
             model = model.split("/", 1)[1]
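The rewrapped condition above is logically unchanged: the prefix is only treated as a provider when the model string actually contains a "/". A standalone illustration of the split logic (the provider list below is an abbreviated stand-in for litellm.provider_list):

provider_list = ["ollama", "openai", "azure"]  # abbreviated stand-in for litellm.provider_list

for model in ["ollama/phi", "mistral"]:
    prefix = model.split("/", 1)[0]
    if prefix in provider_list and len(model.split("/")) > 1:
        custom_llm_provider, model = model.split("/", 1)
        print(model, "->", custom_llm_provider)      # "phi -> ollama"
    else:
        print(model, "-> no explicit provider prefix")  # e.g. bare "mistral"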
@@ -4137,15 +4140,15 @@ def get_llm_provider(
         raise e
     else:
         raise litellm.exceptions.BadRequestError(  # type: ignore
-            message=f"GetLLMProvider Exception - {str(e)}\n\noriginal model: {model}",
-            model=model,
-            response=httpx.Response(
-                status_code=400,
-                content=error_str,
-                request=httpx.request(method="completion", url="https://github.com/BerriAI/litellm"),  # type: ignore
-            ),
-            llm_provider="",
-        )
+            message=f"GetLLMProvider Exception - {str(e)}\n\noriginal model: {model}",
+            model=model,
+            response=httpx.Response(
+                status_code=400,
+                content=error_str,
+                request=httpx.request(method="completion", url="https://github.com/BerriAI/litellm"),  # type: ignore
+            ),
+            llm_provider="",
+        )
 
 
 def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
@@ -6460,6 +6463,13 @@ def exception_type(
                         model=model,
                         response=original_exception.response,
                     )
+                elif "Read timed out" in error_str:
+                    exception_mapping_worked = True
+                    raise Timeout(
+                        message=f"OllamaException: {original_exception}",
+                        llm_provider="ollama",
+                        model=model,
+                    )
             elif custom_llm_provider == "vllm":
                 if hasattr(original_exception, "status_code"):
                     if original_exception.status_code == 0:
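The added branch maps a raw "Read timed out" error from the Ollama request to litellm's Timeout exception, so Ollama timeouts surface the same way as other providers'. A caller-side sketch, assuming Timeout is importable from litellm.exceptions and reusing the placeholder endpoint from the test above:

import litellm
from litellm.exceptions import Timeout

litellm.request_timeout = 0.1  # aggressive timeout to trigger the mapped error

try:
    litellm.completion(
        model="ollama/phi",
        messages=[{"role": "user", "content": "hello, what llm are u"}],
        api_base="https://test-ollama-endpoint.onrender.com",
    )
except Timeout as e:
    print(f"OllamaException mapped to Timeout: {e}")
finally:
    litellm.request_timeout = None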