fix(ollama.py): fix async completion calls for ollama

Krrish Dholakia 2023-12-13 13:10:25 -08:00
parent 52375e0377
commit 7b8851cce5
7 changed files with 35 additions and 17 deletions

View file

@@ -58,7 +58,7 @@ class LangsmithLogger:
             "inputs": {
                 **new_kwargs
             },
-            "outputs": response_obj,
+            "outputs": response_obj.json(),
             "session_name": project_name,
             "start_time": start_time,
             "end_time": end_time,

View file

@@ -219,7 +219,6 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
     except Exception as e:
         traceback.print_exc()
-
 async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
     try:
         timeout = aiohttp.ClientTimeout(total=600)  # 10 minutes
@@ -230,12 +229,12 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
                 text = await resp.text()
                 raise OllamaError(status_code=resp.status, message=text)
 
+            completion_string = ""
             async for line in resp.content.iter_any():
                 if line:
                     try:
                         json_chunk = line.decode("utf-8")
                         chunks = json_chunk.split("\n")
-                        completion_string = ""
                         for chunk in chunks:
                             if chunk.strip() != "":
                                 j = json.loads(chunk)
@@ -245,14 +244,16 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
                                         "content": "",
                                         "error": j
                                     }
+                                    raise Exception(f"OllamError - {chunk}")
                                 if "response" in j:
                                     completion_obj = {
                                         "role": "assistant",
                                         "content": j["response"],
                                     }
-                                    completion_string += completion_obj["content"]
+                                    completion_string = completion_string + completion_obj["content"]
                     except Exception as e:
                         traceback.print_exc()
 
            ## RESPONSE OBJECT
            model_response["choices"][0]["finish_reason"] = "stop"
            model_response["choices"][0]["message"]["content"] = completion_string

View file

@@ -624,7 +624,6 @@ def completion(
             or "ft:babbage-002" in model
             or "ft:davinci-002" in model # support for finetuned completion models
         ):
-            # print("calling custom openai provider")
             openai.api_type = "openai"
 
             api_base = (
@@ -1319,13 +1318,8 @@ def completion(
                 )
             else:
                 prompt = prompt_factory(model=model, messages=messages, custom_llm_provider=custom_llm_provider)
 
             ## LOGGING
-            if kwargs.get('acompletion', False) == True:
-                if optional_params.get("stream", False) == True:
-                    # assume all ollama responses are streamed
-                    async_generator = ollama.async_get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging)
-                    return async_generator
             generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding)
             if acompletion is True:
                 return generator
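
With the old kwargs.get('acompletion') branch removed, async Ollama calls go through the same get_ollama_response_stream(...) entry point, and ollama_acompletion assembles the final ModelResponse. A rough usage sketch, assuming a local Ollama server on the default port with the llama2 model pulled:

import asyncio
import litellm

async def main():
    # Non-streaming async call: routed to ollama_acompletion, which now
    # accumulates the streamed chunks into a single message.
    response = await litellm.acompletion(
        model="ollama/llama2",                      # assumes `ollama pull llama2` was run
        messages=[{"role": "user", "content": "Say hi in one word."}],
        api_base="http://localhost:11434",          # default local Ollama endpoint
    )
    print(response["choices"][0]["message"]["content"])

asyncio.run(main())
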
@@ -2126,7 +2120,7 @@ def text_completion(
         *args,
         **all_params,
     )
-    #print(response)
     text_completion_response["id"] = response.get("id", None)
     text_completion_response["object"] = "text_completion"
     text_completion_response["created"] = response.get("created", None)

View file

@@ -1004,6 +1004,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
         ### ROUTE THE REQUEST ###
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
         if llm_router is not None and data["model"] in router_model_names: # model in router model list
+            print(f"ENTERS LLM ROUTER ACOMPLETION")
             response = await llm_router.acompletion(**data)
         elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
             response = await llm_router.acompletion(**data, specific_deployment = True)
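
For context on the branch this debug print lands in: requests whose model name matches the router's model list are dispatched through llm_router.acompletion. A minimal, hypothetical router setup mirroring that path (credentials assumed to be in the environment):

import asyncio
import litellm

# One deployment whose public name matches what clients send to the proxy.
router = litellm.Router(model_list=[
    {
        "model_name": "gpt-3.5-turbo",                 # name matched against router_model_names
        "litellm_params": {"model": "gpt-3.5-turbo"},  # actual deployment parameters
    }
])

async def main():
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )
    print(response["choices"][0]["message"]["content"])

asyncio.run(main())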

View file

@@ -64,7 +64,7 @@ class ProxyLogging:
         1. /chat/completions
         2. /embeddings
         """
         try:
             self.call_details["data"] = data
             self.call_details["call_type"] = call_type
             ## check if max parallel requests set
@@ -75,6 +75,7 @@ class ProxyLogging:
                     api_key=user_api_key_dict.api_key,
                     user_api_key_cache=self.call_details["user_api_key_cache"])
 
+            print_verbose(f'final data being sent to {call_type} call: {data}')
             return data
         except Exception as e:
             raise e
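
The added print_verbose call only emits output when verbose logging is enabled. A hypothetical stand-in for that helper (the flag name below is an assumption, not litellm's actual implementation):

verbose_proxy_logging_enabled = False  # assumed flag name for illustration only

def print_verbose(message: str) -> None:
    # Debug output is gated on a verbosity flag rather than printed unconditionally.
    if verbose_proxy_logging_enabled:
        print(message)

print_verbose("final data being sent to completion call: {...}")  # silent unless enabled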

View file

@@ -76,7 +76,7 @@ def test_vertex_ai():
     test_models += litellm.vertex_language_models # always test gemini-pro
     for model in test_models:
         try:
-            if model in ["code-gecko@001", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
+            if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
                 # our account does not have access to this model
                 continue
             print("making request", model)
@@ -97,10 +97,11 @@ def test_vertex_ai_stream():
     test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
     test_models = random.sample(test_models, 4)
-    test_models += litellm.vertex_language_models # always test gemini-pro
+    # test_models += litellm.vertex_language_models # always test gemini-pro
+    test_models = ["code-gecko@001"]
 
     for model in test_models:
         try:
-            if model in ["code-gecko@001", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
+            if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
                 # our account does not have access to this model
                 continue
             print("making request", model)
@@ -116,7 +117,7 @@ def test_vertex_ai_stream():
             assert len(completed_str) > 4
         except Exception as e:
             pytest.fail(f"Error occurred: {e}")
-# test_vertex_ai_stream()
+test_vertex_ai_stream()
 
 @pytest.mark.asyncio
 async def test_async_vertexai_response():
@@ -127,6 +128,9 @@ async def test_async_vertexai_response():
     test_models += litellm.vertex_language_models # always test gemini-pro
     for model in test_models:
         print(f'model being tested in async call: {model}')
+        if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
+            # our account does not have access to this model
+            continue
         try:
             user_message = "Hello, how are you?"
             messages = [{"content": user_message, "role": "user"}]
@@ -147,6 +151,9 @@ async def test_async_vertexai_streaming_response():
     test_models = random.sample(test_models, 4)
     test_models += litellm.vertex_language_models # always test gemini-pro
     for model in test_models:
+        if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
+            # our account does not have access to this model
+            continue
         try:
             user_message = "Hello, how are you?"
             messages = [{"content": user_message, "role": "user"}]

View file

@@ -365,6 +365,13 @@ class ModelResponse(OpenAIObject):
     def __setitem__(self, key, value):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)
 
+    def json(self, **kwargs):
+        try:
+            return self.model_dump() # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
 class Embedding(OpenAIObject):
     embedding: list = []
@@ -430,6 +437,13 @@ class EmbeddingResponse(OpenAIObject):
     def __setitem__(self, key, value):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)
 
+    def json(self, **kwargs):
+        try:
+            return self.model_dump() # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
 class TextChoices(OpenAIObject):
     def __init__(self, finish_reason=None, index=0, text=None, logprobs=None, **params):
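
The new .json() helpers bridge the pydantic v1/v2 split: model_dump() only exists on v2, while .dict() is the v1 API (still present on v2 but deprecated). A self-contained sketch of the same compatibility shim outside litellm, with an illustrative model name:

from pydantic import BaseModel

class Response(BaseModel):
    id: str = "resp-123"
    object: str = "chat.completion"

    def json(self, **kwargs):
        # pydantic v2 path: model_dump() returns a plain dict
        try:
            return self.model_dump()
        except AttributeError:
            # pydantic v1 fallback: .dict() is the equivalent call
            return self.dict()

print(Response().json())  # {'id': 'resp-123', 'object': 'chat.completion'}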