forked from phoenix/litellm-mirror

fix(ollama.py): fix async completion calls for ollama

parent 52375e0377
commit 7b8851cce5

7 changed files with 35 additions and 17 deletions
@@ -58,7 +58,7 @@ class LangsmithLogger:
             "inputs": {
                 **new_kwargs
             },
-            "outputs": response_obj,
+            "outputs": response_obj.json(),
             "session_name": project_name,
             "start_time": start_time,
             "end_time": end_time,
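Note: the Langsmith payload now calls `response_obj.json()` because the response object is a pydantic model, which the standard-library `json` module cannot serialize directly. A minimal illustrative stand-in (not litellm's classes) of that failure mode and the fix:

    import json
    from pydantic import BaseModel

    class FakeResponse(BaseModel):
        content: str = "hi"

    resp = FakeResponse()
    # json.dumps({"outputs": resp}) would raise TypeError: the pydantic object
    # itself is not JSON serializable.
    payload = {"outputs": resp.model_dump()}  # what response_obj.json() returns here
    print(json.dumps(payload))                # {"outputs": {"content": "hi"}}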
@@ -219,7 +219,6 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
     except Exception as e:
         traceback.print_exc()
-
 
 async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
     try:
         timeout = aiohttp.ClientTimeout(total=600) # 10 minutes
@@ -230,12 +229,12 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
                 text = await resp.text()
                 raise OllamaError(status_code=resp.status, message=text)
 
+            completion_string = ""
             async for line in resp.content.iter_any():
                 if line:
                     try:
                         json_chunk = line.decode("utf-8")
                         chunks = json_chunk.split("\n")
-                        completion_string = ""
                         for chunk in chunks:
                             if chunk.strip() != "":
                                 j = json.loads(chunk)
@@ -245,14 +244,16 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
                                         "content": "",
                                         "error": j
                                     }
+                                    raise Exception(f"OllamError - {chunk}")
+                                if "response" in j:
                                     completion_obj = {
                                         "role": "assistant",
                                         "content": j["response"],
                                     }
-                                    completion_string += completion_obj["content"]
+                                    completion_string = completion_string + completion_obj["content"]
                     except Exception as e:
                         traceback.print_exc()
 
             ## RESPONSE OBJECT
             model_response["choices"][0]["finish_reason"] = "stop"
             model_response["choices"][0]["message"]["content"] = completion_string
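The core of the ollama fix above is that `completion_string` is now initialized once, before the read loop, instead of being reset on every chunk that arrives from the server. A self-contained sketch of that accumulation pattern, with a fake async generator standing in for the aiohttp response body (illustrative only):

    import asyncio
    import json

    async def fake_ollama_stream():
        # each network read can carry one or more newline-delimited JSON chunks
        yield b'{"response": "Hello"}\n{"response": ", "}\n'
        yield b'{"response": "world"}\n{"done": true}\n'

    async def collect_response():
        completion_string = ""  # initialized once, outside the read loop
        async for line in fake_ollama_stream():
            for chunk in line.decode("utf-8").split("\n"):
                if chunk.strip() != "":
                    j = json.loads(chunk)
                    if "response" in j:
                        completion_string += j["response"]
        return completion_string

    print(asyncio.run(collect_response()))  # -> Hello, world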
@@ -624,7 +624,6 @@ def completion(
             or "ft:babbage-002" in model
             or "ft:davinci-002" in model # support for finetuned completion models
         ):
-            # print("calling custom openai provider")
             openai.api_type = "openai"
 
             api_base = (
@@ -1319,13 +1318,8 @@ def completion(
                 )
             else:
                 prompt = prompt_factory(model=model, messages=messages, custom_llm_provider=custom_llm_provider)
-            ## LOGGING
-            if kwargs.get('acompletion', False) == True:
-                if optional_params.get("stream", False) == True:
-                    # assume all ollama responses are streamed
-                    async_generator = ollama.async_get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging)
-                    return async_generator
-
             ## LOGGING
             generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding)
+            if acompletion is True:
+                return generator
 
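With the change above, async ollama calls no longer take a separate `async_get_ollama_response_stream` path; `get_ollama_response_stream` is called once and, when `acompletion` is set, its result is handed straight back to the async caller. A rough sketch of that dispatch shape (the `backend` object and helper names are hypothetical, not litellm's API):

    # hypothetical backend with a single entry point for sync and async callers
    def dispatch_ollama(api_base, model, prompt, optional_params, acompletion, backend):
        generator = backend.get_response_stream(api_base, model, prompt, optional_params)
        if acompletion is True:
            # async callers receive the coroutine/generator and await or iterate it themselves
            return generator
        # sync callers consume the streamed chunks directly
        return list(generator)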
@@ -2126,7 +2120,7 @@ def text_completion(
         *args,
         **all_params,
     )
     #print(response)
 
     text_completion_response["id"] = response.get("id", None)
     text_completion_response["object"] = "text_completion"
     text_completion_response["created"] = response.get("created", None)
@@ -1004,6 +1004,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
         ### ROUTE THE REQUEST ###
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
         if llm_router is not None and data["model"] in router_model_names: # model in router model list
+            print(f"ENTERS LLM ROUTER ACOMPLETION")
             response = await llm_router.acompletion(**data)
         elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
             response = await llm_router.acompletion(**data, specific_deployment = True)
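The proxy hunk above routes on the requested model name: if it matches a name registered with the router, the router load-balances the call; if it matches a specific deployment name, that deployment is addressed directly. A simplified sketch of that decision (the surrounding objects and the fallback are illustrative):

    async def route_request(data, llm_router, llm_model_list):
        router_model_names = (
            [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
        )
        if llm_router is not None and data["model"] in router_model_names:
            # model registered with the router: let it pick a deployment
            return await llm_router.acompletion(**data)
        if llm_router is not None and data["model"] in llm_router.deployment_names:
            # caller addressed one specific deployment on the router
            return await llm_router.acompletion(**data, specific_deployment=True)
        raise ValueError(f"unknown model: {data['model']}")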
@@ -75,6 +75,7 @@ class ProxyLogging:
                     api_key=user_api_key_dict.api_key,
                     user_api_key_cache=self.call_details["user_api_key_cache"])
 
+            print_verbose(f'final data being sent to {call_type} call: {data}')
             return data
         except Exception as e:
             raise e
@@ -76,7 +76,7 @@ def test_vertex_ai():
     test_models += litellm.vertex_language_models # always test gemini-pro
     for model in test_models:
         try:
-            if model in ["code-gecko@001", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
+            if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
                 # our account does not have access to this model
                 continue
             print("making request", model)
@@ -97,10 +97,11 @@ def test_vertex_ai_stream():
 
     test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
     test_models = random.sample(test_models, 4)
-    test_models += litellm.vertex_language_models # always test gemini-pro
+    # test_models += litellm.vertex_language_models # always test gemini-pro
+    test_models = ["code-gecko@001"]
     for model in test_models:
         try:
-            if model in ["code-gecko@001", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
+            if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
                 # our account does not have access to this model
                 continue
             print("making request", model)
@@ -116,7 +117,7 @@ def test_vertex_ai_stream():
             assert len(completed_str) > 4
         except Exception as e:
             pytest.fail(f"Error occurred: {e}")
-# test_vertex_ai_stream()
+test_vertex_ai_stream()
 
 @pytest.mark.asyncio
 async def test_async_vertexai_response():
@@ -127,6 +128,9 @@ async def test_async_vertexai_response():
     test_models += litellm.vertex_language_models # always test gemini-pro
     for model in test_models:
         print(f'model being tested in async call: {model}')
+        if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
+            # our account does not have access to this model
+            continue
         try:
             user_message = "Hello, how are you?"
             messages = [{"content": user_message, "role": "user"}]
@@ -147,6 +151,9 @@ async def test_async_vertexai_streaming_response():
     test_models = random.sample(test_models, 4)
     test_models += litellm.vertex_language_models # always test gemini-pro
     for model in test_models:
+        if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
+            # our account does not have access to this model
+            continue
         try:
             user_message = "Hello, how are you?"
             messages = [{"content": user_message, "role": "user"}]
@@ -366,6 +366,13 @@ class ModelResponse(OpenAIObject):
             # Allow dictionary-style assignment of attributes
             setattr(self, key, value)
 
+    def json(self, **kwargs):
+        try:
+            return self.model_dump() # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
 class Embedding(OpenAIObject):
     embedding: list = []
     index: int
@@ -431,6 +438,13 @@ class EmbeddingResponse(OpenAIObject):
             # Allow dictionary-style assignment of attributes
             setattr(self, key, value)
 
+    def json(self, **kwargs):
+        try:
+            return self.model_dump() # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
 class TextChoices(OpenAIObject):
     def __init__(self, finish_reason=None, index=0, text=None, logprobs=None, **params):
         super(TextChoices, self).__init__(**params)
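The `json()` helper added to both response classes is a version-tolerant serializer: pydantic v2 exposes `model_dump()`, while v1 only has `dict()`, so the method tries the former and falls back to the latter. A standalone sketch of the same idea on a plain pydantic model (not litellm's classes; the method name here is illustrative):

    from pydantic import BaseModel

    class Reply(BaseModel):
        role: str = "assistant"
        content: str = ""

        def as_dict(self, **kwargs):
            try:
                return self.model_dump()  # pydantic v2
            except AttributeError:
                return self.dict()        # pydantic v1 fallback

    print(Reply(content="hi").as_dict())  # {'role': 'assistant', 'content': 'hi'}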