Updated cohere v2 passthrough (#9997)

* Add cohere `/v2/chat` pass-through cost tracking support (#8235)

* feat(cohere_passthrough_handler.py): initial working commit with cohere passthrough cost tracking

* fix(v2_transformation.py): support cohere /v2/chat endpoint

* fix: fix linting errors

* fix: fix import

* fix(v2_transformation.py): fix linting error

* test: handle openai exception change
This commit is contained in:
Krish Dholakia 2025-04-14 19:51:01 -07:00 committed by GitHub
parent db857c74d4
commit 2ed593e052
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 742 additions and 20 deletions

View file

@ -104,19 +104,28 @@ class ModelResponseIterator:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
str_line = chunk
if isinstance(chunk, bytes): # Handle binary data
str_line = chunk.decode("utf-8") # Convert bytes to string
index = str_line.find("data:")
if index != -1:
str_line = str_line[index:]
data_json = json.loads(str_line)
return self.chunk_parser(chunk=data_json)
return self.convert_str_chunk_to_generic_chunk(chunk=chunk)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
def convert_str_chunk_to_generic_chunk(self, chunk: str) -> GenericStreamingChunk:
"""
Convert a string chunk to a GenericStreamingChunk
Note: This is used for Cohere pass through streaming logging
"""
str_line = chunk
if isinstance(chunk, bytes): # Handle binary data
str_line = chunk.decode("utf-8") # Convert bytes to string
index = str_line.find("data:")
if index != -1:
str_line = str_line[index:]
data_json = json.loads(str_line)
return self.chunk_parser(chunk=data_json)
# Async iterator
def __aiter__(self):
self.async_response_iterator = self.streaming_response.__aiter__()
@ -131,15 +140,7 @@ class ModelResponseIterator:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
str_line = chunk
if isinstance(chunk, bytes): # Handle binary data
str_line = chunk.decode("utf-8") # Convert bytes to string
index = str_line.find("data:")
if index != -1:
str_line = str_line[index:]
data_json = json.loads(str_line)
return self.chunk_parser(chunk=data_json)
return self.convert_str_chunk_to_generic_chunk(chunk=chunk)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e: