mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00

fix(openai.py-+-azure.py): fix linting issues

This commit is contained in:
parent 311b4f9d2c
commit 1306addfe8

4 changed files with 30 additions and 14 deletions
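The recurring edit in azure.py and openai.py is the same at each call site: instead of passing data.pop("max_retries", 2) inline to the client constructor, the value is popped first, checked to be an int, and only then handed to the client. A minimal standalone sketch of that pattern follows; the helper name build_azure_client and the locally defined AzureOpenAIError stand-in are illustrative only, not the actual litellm code.

# Sketch only: mirrors the pop-validate-construct pattern added in this commit.
from openai import AzureOpenAI


class AzureOpenAIError(Exception):  # stand-in for the provider error class used in the diff
    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        super().__init__(message)


def build_azure_client(data: dict, api_key: str, api_base: str, api_version: str) -> AzureOpenAI:
    # Pop max_retries out of the request kwargs so it is not forwarded to
    # chat.completions.create(), and validate its type before building the client.
    max_retries = data.pop("max_retries", 2)
    if not isinstance(max_retries, int):
        raise AzureOpenAIError(status_code=422, message="max retries must be an int")
    return AzureOpenAI(
        api_key=api_key,
        api_version=api_version,
        azure_endpoint=api_base,
        max_retries=max_retries,
    )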
@@ -134,8 +134,6 @@ class APIResponseValidationError(APIResponseValidationError): # type: ignore
             message=message
         )
-
-
 
 class OpenAIError(OpenAIError): # type: ignore
     def __init__(self, original_exception):
         self.status_code = original_exception.http_status
@@ -136,7 +136,10 @@ class AzureChatCompletion(BaseLLM):
             elif "stream" in optional_params and optional_params["stream"] == True:
                 return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout)
             else:
-                azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
+                max_retries = data.pop("max_retries", 2)
+                if not isinstance(max_retries, int):
+                    raise AzureOpenAIError(status_code=422, message="max retries must be an int")
+                azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
                 response = azure_client.chat.completions.create(**data) # type: ignore
                 return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except AzureOpenAIError as e:
@@ -156,7 +159,10 @@ class AzureChatCompletion(BaseLLM):
                           azure_ad_token: Optional[str]=None, ):
        response = None
        try:
-            azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
+            max_retries = data.pop("max_retries", 2)
+            if not isinstance(max_retries, int):
+                raise AzureOpenAIError(status_code=422, message="max retries must be an int")
+            azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
            response = await azure_client.chat.completions.create(**data)
            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
        except Exception as e:
@@ -177,7 +183,10 @@ class AzureChatCompletion(BaseLLM):
                  timeout: Any,
                  azure_ad_token: Optional[str]=None,
    ):
-        azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
+        max_retries = data.pop("max_retries", 2)
+        if not isinstance(max_retries, int):
+            raise AzureOpenAIError(status_code=422, message="max retries must be an int")
+        azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
        response = azure_client.chat.completions.create(**data)
        streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
        for transformed_chunk in streamwrapper:
@@ -218,7 +227,10 @@ class AzureChatCompletion(BaseLLM):
                "input": input,
                **optional_params
            }
-            azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, max_retries=data.pop("max_retries", 2))
+            max_retries = data.pop("max_retries", 2)
+            if not isinstance(max_retries, int):
+                raise AzureOpenAIError(status_code=422, message="max retries must be an int")
+            azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, max_retries=max_retries)

            ## LOGGING
            logging_obj.pre_call(
@@ -209,7 +209,10 @@ class OpenAIChatCompletion(BaseLLM):
                elif optional_params.get("stream", False):
                    return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout)
                else:
-                    openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
+                    max_retries = data.pop("max_retries", 2)
+                    if not isinstance(max_retries, int):
+                        raise OpenAIError(status_code=422, message="max retries must be an int")
+                    openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
                    response = openai_client.chat.completions.create(**data) # type: ignore
                    logging_obj.post_call(
                        input=None,
@@ -317,7 +320,10 @@ class OpenAIChatCompletion(BaseLLM):
                "input": input,
                **optional_params
            }
-            openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, max_retries=data.pop("max_retries", 2))
+            max_retries = data.pop("max_retries", 2)
+            if not isinstance(max_retries, int):
+                raise OpenAIError(status_code=422, message="max retries must be an int")
+            openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, max_retries=max_retries)

            ## LOGGING
            logging_obj.pre_call(
@@ -1082,7 +1082,7 @@ class Rules:
            if callable(rule):
                decision = rule(input)
                if decision is False:
-                    raise litellm.APIResponseValidationError("LLM Response failed post-call-rule check", llm_provider=custom_llm_provider, model=model)
+                    raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider=custom_llm_provider, model=model) # type: ignore
        return True

    def post_call_rules(self, input: str, model: str):
@@ -1091,7 +1091,7 @@ class Rules:
            if callable(rule):
                decision = rule(input)
                if decision is False:
-                    raise litellm.APIResponseValidationError("LLM Response failed post-call-rule check", llm_provider=custom_llm_provider, model=model)
+                    raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider=custom_llm_provider, model=model) # type: ignore
        return True

####### CLIENT ###################
@@ -3135,10 +3135,10 @@ def convert_to_model_response_object(response_object: Optional[dict]=None, model
            choice_list.append(choice)
        model_response_object.choices = choice_list

-        if "usage" in response_object:
-            model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0)
-            model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0)
-            model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0)
+        if "usage" in response_object and response_object["usage"] is not None:
+            model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore
+            model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore
+            model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore

        if "id" in response_object:
            model_response_object.id = response_object["id"]