(feat) litellm.completion - support ollama timeout

ishaan-jaff 2024-01-09 10:31:01 +05:30
parent 10f76ec36c
commit 77027746ba
3 changed files with 49 additions and 17 deletions
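For context, a minimal usage sketch of what this feature enables, assuming the standard `timeout` keyword on `litellm.completion`; the model name and value below are illustrative, not taken from this commit:

import litellm

# Hypothetical example: call a local Ollama model with a client-side timeout.
response = litellm.completion(
    model="ollama/llama2",
    messages=[{"role": "user", "content": "hello"}],
    timeout=10,  # seconds; illustrative value
)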

@@ -558,7 +558,7 @@ class TextChoices(OpenAIObject):
    def __setitem__(self, key, value):
        # Allow dictionary-style assignment of attributes
        setattr(self, key, value)

    def json(self, **kwargs):
        try:
            return self.model_dump()  # noqa
@@ -737,7 +737,9 @@ class Logging:
                f"Invalid call_type {call_type}. Allowed values: {allowed_values}"
            )
        if messages is not None and isinstance(messages, str):
-            messages = [{"role": "user", "content": messages}]  # convert text completion input to the chat completion format
+            messages = [
+                {"role": "user", "content": messages}
+            ]  # convert text completion input to the chat completion format
        self.model = model
        self.messages = messages
        self.stream = stream
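As a quick illustration of the conversion noted in the comment above (the prompt text is illustrative):

# A plain string prompt is wrapped into the chat-completion message format.
messages = "what is the weather?"
messages = [{"role": "user", "content": messages}]
# -> [{"role": "user", "content": "what is the weather?"}]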
@@ -4002,7 +4004,8 @@ def get_llm_provider(
        if (
            model.split("/", 1)[0] in litellm.provider_list
            and model.split("/", 1)[0] not in litellm.model_list
-            and len(model.split("/")) > 1  # handle edge case where user passes in `litellm --model mistral` https://github.com/BerriAI/litellm/issues/1351
+            and len(model.split("/"))
+            > 1  # handle edge case where user passes in `litellm --model mistral` https://github.com/BerriAI/litellm/issues/1351
        ):
            custom_llm_provider = model.split("/", 1)[0]
            model = model.split("/", 1)[1]
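The `len(model.split("/")) > 1` guard covers models passed without a provider prefix (e.g. `litellm --model mistral`, issue 1351). A rough standalone sketch of the check; the loop and variable names are illustrative:

# "mistral" contains no "/", so it must not be split into provider/model.
for model in ["mistral", "ollama/llama2"]:
    prefix = model.split("/", 1)[0]
    has_provider_prefix = len(model.split("/")) > 1
    print(model, prefix, has_provider_prefix)
# mistral mistral False      -> leave the model name untouched
# ollama/llama2 ollama True  -> strip the "ollama/" prefix and use provider "ollama"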
@@ -4137,15 +4140,15 @@ def get_llm_provider(
            raise e
        else:
            raise litellm.exceptions.BadRequestError(  # type: ignore
-                message=f"GetLLMProvider Exception - {str(e)}\n\noriginal model: {model}",
-                model=model,
-                response=httpx.Response(
-                    status_code=400,
-                    content=error_str,
-                    request=httpx.request(method="completion", url="https://github.com/BerriAI/litellm"),  # type: ignore
-                ),
-                llm_provider="",
-            )
+                message=f"GetLLMProvider Exception - {str(e)}\n\noriginal model: {model}",
+                model=model,
+                response=httpx.Response(
+                    status_code=400,
+                    content=error_str,
+                    request=httpx.request(method="completion", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                ),
+                llm_provider="",
+            )
def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
@@ -6460,6 +6463,13 @@ def exception_type(
                    model=model,
                    response=original_exception.response,
                )
+            elif "Read timed out" in error_str:
+                exception_mapping_worked = True
+                raise Timeout(
+                    message=f"OllamaException: {original_exception}",
+                    llm_provider="ollama",
+                    model=model,
+                )
        elif custom_llm_provider == "vllm":
            if hasattr(original_exception, "status_code"):
                if original_exception.status_code == 0:
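With the new mapping, a read timeout from the Ollama server surfaces as litellm's `Timeout` exception instead of falling through to a generic error. A hedged sketch of how a caller might rely on this; the tiny timeout value and model name are illustrative:

import litellm

try:
    litellm.completion(
        model="ollama/llama2",
        messages=[{"role": "user", "content": "write a long story"}],
        timeout=0.01,  # deliberately small so the underlying read times out
    )
except litellm.exceptions.Timeout as e:
    # "Read timed out" from the Ollama request is now mapped to Timeout
    print("request timed out:", e)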