(Refactor) Code Quality improvement - Use Common base handler for Cohere /generate API (#7122)

* use validate_environment in common utils

* use transform request / response for cohere

* remove unused file

* use cohere base_llm_http_handler

* working cohere generate api on llm http handler

* streaming cohere generate api

* fix get_model_response_iterator

* fix streaming handler

* fix get_model_response_iterator

* test_cohere_generate_api_completion

* fix linting error

* fix testing cohere raising error

* fix get_model_response_iterator type

* add testing cohere generate api
This commit is contained in:
Ishaan Jaff 2024-12-10 10:44:42 -08:00 committed by GitHub
parent 9c2316b7ec
commit 1b377d5229
9 changed files with 439 additions and 382 deletions

View file

@ -411,32 +411,6 @@ class CustomStreamWrapper:
except Exception:
raise ValueError(f"Unable to parse response. Original response: {chunk}")
def handle_cohere_chunk(self, chunk):
chunk = chunk.decode("utf-8")
data_json = json.loads(chunk)
try:
text = ""
is_finished = False
finish_reason = ""
index: Optional[int] = None
if "index" in data_json:
index = data_json.get("index")
if "text" in data_json:
text = data_json["text"]
elif "is_finished" in data_json:
is_finished = data_json["is_finished"]
finish_reason = data_json["finish_reason"]
else:
raise Exception(data_json)
return {
"index": index,
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
except Exception:
raise ValueError(f"Unable to parse response. Original response: {chunk}")
def handle_azure_chunk(self, chunk):
is_finished = False
finish_reason = ""
@ -1157,11 +1131,6 @@ class CustomStreamWrapper:
)
else:
completion_obj["content"] = str(chunk)
elif self.custom_llm_provider == "cohere":
response_obj = self.handle_cohere_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "petals":
if len(self.completion_stream) == 0:
if self.received_finish_reason is not None:
@ -1669,6 +1638,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "text-completion-codestral"
or self.custom_llm_provider == "azure_text"
or self.custom_llm_provider == "cohere_chat"
or self.custom_llm_provider == "cohere"
or self.custom_llm_provider == "anthropic"
or self.custom_llm_provider == "anthropic_text"
or self.custom_llm_provider == "huggingface"