mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
move cohere to http endpoint
commit 3b4064a58f (parent 70a36740bc)
8 changed files with 175 additions and 45 deletions
@@ -1854,7 +1854,7 @@ def exception_type(model, original_exception, custom_llm_provider):
                     llm_provider="replicate",
                     model=model
                 )
-        elif model in litellm.cohere_models: # Cohere
+        elif model in litellm.cohere_models or custom_llm_provider == "cohere": # Cohere
             if (
                 "invalid api token" in error_str
                 or "No API key provided." in error_str
@@ -1872,6 +1872,21 @@ def exception_type(model, original_exception, custom_llm_provider):
                     model=model,
                     llm_provider="cohere",
                 )
+            elif hasattr(original_exception, "status_code"):
+                if original_exception.status_code == 400 or original_exception.status_code == 498:
+                    exception_mapping_worked = True
+                    raise InvalidRequestError(
+                        message=f"CohereException - {original_exception.message}",
+                        llm_provider="cohere",
+                        model=model
+                    )
+                elif original_exception.status_code == 500:
+                    exception_mapping_worked = True
+                    raise ServiceUnavailableError(
+                        message=f"CohereException - {original_exception.message}",
+                        llm_provider="cohere",
+                        model=model
+                    )
             elif (
                 "CohereConnectionError" in exception_type
             ): # cohere seems to fire these errors when we load test it (1k+ messages / min)
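Note on the two hunks above: with Cohere now called over a raw HTTP endpoint, errors surface as objects carrying a status_code rather than as cohere-SDK exception classes, hence the new hasattr(original_exception, "status_code") branch. A minimal self-contained sketch of the same mapping idea; the two exception classes below are local stand-ins for litellm's, and CohereHTTPError is a hypothetical placeholder for the raw error:

    # local stand-ins, not litellm's actual exception classes
    class InvalidRequestError(Exception):
        pass

    class ServiceUnavailableError(Exception):
        pass

    class CohereHTTPError(Exception):  # hypothetical raw HTTP error
        def __init__(self, status_code, message):
            self.status_code = status_code
            self.message = message
            super().__init__(message)

    def map_cohere_http_error(original_exception, model):
        # mirrors the hunk: 400/498 -> invalid request, 500 -> service unavailable
        if original_exception.status_code in (400, 498):
            raise InvalidRequestError(f"CohereException - {original_exception.message}")
        if original_exception.status_code == 500:
            raise ServiceUnavailableError(f"CohereException - {original_exception.message}")
        raise original_exception  # anything else falls through to the generic handler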
@@ -2287,14 +2302,10 @@ class CustomStreamWrapper:
         self.model = model
         self.custom_llm_provider = custom_llm_provider
         self.logging_obj = logging_obj
+        self.completion_stream = completion_stream
         if self.logging_obj:
             # Log the type of the received item
             self.logging_obj.post_call(str(type(completion_stream)))
-        if model in litellm.cohere_models:
-            # these do not return an iterator, so we need to wrap it in one
-            self.completion_stream = iter(completion_stream)
-        else:
-            self.completion_stream = completion_stream
 
     def __iter__(self):
         return self
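The constructor change above drops the cohere-specific iter() wrap: the old SDK stream was presumably iterable but not an iterator, so next() on it failed without wrapping, whereas an HTTP response stream (e.g. requests' iter_lines()) is already an iterator. A short sketch of that distinction, using a made-up class:

    class IterableOnly:
        # iterable (defines __iter__) but not an iterator (no __next__)
        def __iter__(self):
            yield from ("a", "b")

    stream = IterableOnly()
    # next(stream) would raise TypeError -- this is why the removed wrap existed
    stream = iter(stream)  # the workaround the old code applied
    assert next(stream) == "a"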
@@ -2359,6 +2370,16 @@ class CustomStreamWrapper:
         except:
             raise ValueError(f"Unable to parse response. Original response: {chunk}")
 
+    def handle_cohere_chunk(self, chunk):
+        chunk = chunk.decode("utf-8")
+        print(f"cohere chunk: {chunk}")
+        data_json = json.loads(chunk)
+        try:
+            print(f"data json: {data_json}")
+            return data_json["text"]
+        except:
+            raise ValueError(f"Unable to parse response. Original response: {chunk}")
+
     def handle_openai_text_completion_chunk(self, chunk):
         try:
             return chunk["choices"][0]["text"]
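handle_cohere_chunk above decodes each raw bytes chunk, parses it as JSON, and returns the "text" field. One quirk worth noting: json.loads sits outside the try, so a malformed line raises json.JSONDecodeError (a ValueError subclass) rather than the custom "Unable to parse response" message. A quick illustration with a fabricated payload; the {"text": ...} shape is what the handler reads, and real Cohere stream chunks may carry additional fields:

    import json

    chunk = b'{"text": " Hello"}'  # fabricated example chunk
    data_json = json.loads(chunk.decode("utf-8"))
    assert data_json["text"] == " Hello"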
@@ -2416,9 +2437,6 @@ class CustomStreamWrapper:
                 if text_data == "":
                     return self.__next__()
                 completion_obj["content"] = text_data
-            elif self.model in litellm.cohere_models:
-                chunk = next(self.completion_stream)
-                completion_obj["content"] = chunk.text
             elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
                 chunk = next(self.completion_stream)
                 completion_obj["content"] = self.handle_huggingface_chunk(chunk)
@@ -2440,6 +2458,9 @@ class CustomStreamWrapper:
             elif self.model in litellm.nlp_cloud_models or self.custom_llm_provider == "nlp_cloud":
                 chunk = next(self.completion_stream)
                 completion_obj["content"] = self.handle_nlp_cloud_chunk(chunk)
+            elif self.model in litellm.cohere_models or self.custom_llm_provider == "cohere":
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_cohere_chunk(chunk)
             else: # openai chat/azure models
                 chunk = next(self.completion_stream)
                 return chunk # open ai returns finish_reason, we should just return the openai chunk
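With this last hunk, streaming chunks are routed through handle_cohere_chunk whenever the model is a known Cohere model or the provider is forced via custom_llm_provider="cohere". A hedged usage sketch; the model name, env var, and chunk shape follow litellm's docs of this era and are illustrative, not verified against this exact commit:

    import os
    import litellm

    os.environ["COHERE_API_KEY"] = "..."  # placeholder key
    response = litellm.completion(
        model="command-nightly",  # a Cohere model known to litellm
        messages=[{"role": "user", "content": "Hi"}],
        stream=True,
    )
    for chunk in response:  # each chunk is produced by CustomStreamWrapper.__next__
        print(chunk["choices"][0]["delta"])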