refactor(azure.py): moving azure openai calls to http calls

Repository: https://github.com/BerriAI/litellm.git
Parent commit: 01a7660a12
This commit:   53abc31c27
7 changed files with 309 additions and 78 deletions
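The commit title says the Azure OpenAI calls are being switched from the SDK to plain HTTP calls. A minimal sketch of what such a call can look like against the standard Azure OpenAI chat-completions REST endpoint, using the requests library; the environment variable names, deployment argument, and api-version below are illustrative assumptions, not taken from this commit.

    # Hedged sketch: direct HTTP call to an Azure OpenAI chat-completions endpoint.
    # Endpoint shape, env-var names, and api-version are assumptions for illustration.
    import os
    import requests

    def azure_chat_completion_http(deployment: str, messages: list, stream: bool = False):
        base = os.environ["AZURE_API_BASE"]        # e.g. https://my-resource.openai.azure.com
        api_key = os.environ["AZURE_API_KEY"]
        api_version = "2023-07-01-preview"         # assumed version, for illustration only
        url = f"{base}/openai/deployments/{deployment}/chat/completions?api-version={api_version}"
        resp = requests.post(
            url,
            headers={"api-key": api_key, "Content-Type": "application/json"},
            json={"messages": messages, "stream": stream},
            stream=stream,
        )
        resp.raise_for_status()
        # For stream=True the caller iterates resp.iter_lines(); otherwise return parsed JSON.
        return resp if stream else resp.json()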
@@ -2881,8 +2881,6 @@ def exception_type(
                        llm_provider="openrouter"
                    )
                original_exception.llm_provider = "openrouter"
            elif custom_llm_provider == "azure":
                original_exception.llm_provider = "azure"
            else:
                original_exception.llm_provider = "openai"
            if "This model's maximum context length is" in original_exception._message:
@@ -3478,6 +3476,9 @@ def exception_type(
                    raise original_exception
            raise original_exception
        elif custom_llm_provider == "ollama":
            if "no attribute 'async_get_ollama_response_stream" in error_str:
                exception_mapping_worked = True
                raise ImportError("Import error - trying to use async for ollama. import async_generator failed. Try 'pip install async_generator'")
            if isinstance(original_exception, dict):
                error_str = original_exception.get("error", "")
            else:
@@ -3512,9 +3513,59 @@ def exception_type(
                    llm_provider="vllm",
                    model=model
                )
        elif custom_llm_provider == "ollama":
            if "no attribute 'async_get_ollama_response_stream" in error_str:
                raise ImportError("Import error - trying to use async for ollama. import async_generator failed. Try 'pip install async_generator'")
        elif custom_llm_provider == "azure":
            if "This model's maximum context length is" in error_str:
                exception_mapping_worked = True
                raise ContextWindowExceededError(
                    message=f"AzureException - {original_exception.message}",
                    llm_provider="azure",
                    model=model
                )
            elif "invalid_request_error" in error_str:
                exception_mapping_worked = True
                raise InvalidRequestError(
                    message=f"AzureException - {original_exception.message}",
                    llm_provider="azure",
                    model=model
                )
            elif hasattr(original_exception, "status_code"):
                exception_mapping_worked = True
                if original_exception.status_code == 401:
                    exception_mapping_worked = True
                    raise AuthenticationError(
                        message=f"AzureException - {original_exception.message}",
                        llm_provider="azure",
                        model=model
                    )
                elif original_exception.status_code == 408:
                    exception_mapping_worked = True
                    raise Timeout(
                        message=f"AzureException - {original_exception.message}",
                        model=model,
                        llm_provider="azure"
                    )
                if original_exception.status_code == 422:
                    exception_mapping_worked = True
                    raise InvalidRequestError(
                        message=f"AzureException - {original_exception.message}",
                        model=model,
                        llm_provider="azure",
                    )
                elif original_exception.status_code == 429:
                    exception_mapping_worked = True
                    raise RateLimitError(
                        message=f"AzureException - {original_exception.message}",
                        model=model,
                        llm_provider="azure",
                    )
                else:
                    exception_mapping_worked = True
                    raise APIError(
                        status_code=original_exception.status_code,
                        message=f"AzureException - {original_exception.message}",
                        llm_provider="azure",
                        model=model
                    )
        elif custom_llm_provider == "custom_openai" or custom_llm_provider == "maritalk":
            if hasattr(original_exception, "status_code"):
                exception_mapping_worked = True
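The new azure branch above maps error strings and HTTP status codes onto litellm's exception types. A self-contained sketch of that same status-code dispatch pattern, with stand-in exception classes rather than litellm's real ones:

    # Standalone sketch of the status-code -> exception mapping used in the azure branch.
    # These exception classes are toy stand-ins, not litellm's classes.
    class AuthenticationError(Exception): ...
    class Timeout(Exception): ...
    class InvalidRequestError(Exception): ...
    class RateLimitError(Exception): ...
    class APIError(Exception): ...

    def map_azure_status(status_code: int, message: str):
        if status_code == 401:
            raise AuthenticationError(f"AzureException - {message}")
        elif status_code == 408:
            raise Timeout(f"AzureException - {message}")
        elif status_code == 422:
            raise InvalidRequestError(f"AzureException - {message}")
        elif status_code == 429:
            raise RateLimitError(f"AzureException - {message}")
        else:
            # Anything unrecognised falls through to a generic API error.
            raise APIError(f"AzureException - {message} (status {status_code})")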
@@ -3853,6 +3904,26 @@ class CustomStreamWrapper:
        except:
            raise ValueError(f"Unable to parse response. Original response: {chunk}")

    def handle_azure_chunk(self, chunk):
        chunk = chunk.decode("utf-8")
        is_finished = False
        finish_reason = ""
        text = ""
        if chunk.startswith("data:"):
            data_json = json.loads(chunk[5:]) # chunk.startswith("data:"):
            try:
                text = data_json["choices"][0]["delta"].get("content", "")
                if data_json["choices"][0].get("finish_reason", None):
                    is_finished = True
                    finish_reason = data_json["choices"][0]["finish_reason"]
                return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
            except:
                raise ValueError(f"Unable to parse response. Original response: {chunk}")
        elif "error" in chunk:
            raise ValueError(f"Unable to parse response. Original response: {chunk}")
        else:
            return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}

    def handle_replicate_chunk(self, chunk):
        try:
            text = ""
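handle_azure_chunk above decodes a raw server-sent-events line, strips the "data:" prefix, and pulls the delta content and finish_reason out of the JSON payload. A standalone sketch of that parsing with an illustrative chunk, not captured from a real Azure response:

    # Sketch of the "data:" chunk parsing performed by handle_azure_chunk.
    # The sample payload below is illustrative only.
    import json

    sample = b'data: {"choices": [{"delta": {"content": "Hello"}, "finish_reason": null}]}'

    def parse_sse_chunk(chunk: bytes):
        chunk = chunk.decode("utf-8")
        text, is_finished, finish_reason = "", False, ""
        if chunk.startswith("data:"):
            data_json = json.loads(chunk[5:])  # drop the "data:" prefix, parse the JSON body
            text = data_json["choices"][0]["delta"].get("content", "")
            if data_json["choices"][0].get("finish_reason", None):
                is_finished = True
                finish_reason = data_json["choices"][0]["finish_reason"]
        return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}

    print(parse_sse_chunk(sample))  # {'text': 'Hello', 'is_finished': False, 'finish_reason': ''}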
@@ -4013,6 +4084,12 @@ class CustomStreamWrapper:
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
                    model_response.choices[0].finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider and self.custom_llm_provider == "azure":
                chunk = next(self.completion_stream)
                response_obj = self.handle_azure_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
                    model_response.choices[0].finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider and self.custom_llm_provider == "maritalk":
                chunk = next(self.completion_stream)
                response_obj = self.handle_maritalk_chunk(chunk)
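In the stream wrapper above, __next__ branches on custom_llm_provider, pulls the next raw chunk, and hands it to the matching handler before copying text and finish_reason into the model response. A toy, self-contained sketch of that dispatch written as a lookup table; the handler bodies here are stand-ins, not litellm code:

    # Sketch of provider -> chunk-handler dispatch; handlers are toy stand-ins.
    def handle_azure(chunk):    return {"text": chunk, "is_finished": False, "finish_reason": ""}
    def handle_maritalk(chunk): return {"text": chunk, "is_finished": False, "finish_reason": ""}

    CHUNK_HANDLERS = {"azure": handle_azure, "maritalk": handle_maritalk}

    def next_piece(provider: str, raw_chunk: str):
        handler = CHUNK_HANDLERS.get(provider)
        if handler is None:
            raise ValueError(f"no streaming handler for provider: {provider}")
        response_obj = handler(raw_chunk)
        text = response_obj["text"]
        finish_reason = response_obj["finish_reason"] if response_obj["is_finished"] else None
        return text, finish_reason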
@@ -4187,7 +4264,7 @@ class TextCompletionStreamWrapper:
        except StopIteration:
            raise StopIteration
        except Exception as e:
            print(f"got exception {e}")
            print(f"got exception {e}") # noqa

    async def __anext__(self):
        try:
            return next(self)