@ -1,19 +1,25 @@
""" Manifest test. """
import json
import asyncio
from typing import cast
import pytest
import requests
from manifest import Manifest , Response
from manifest . caches . noop import NoopCache
from manifest . caches . sqlite import SQLiteCache
from manifest . clients . dummy import DummyClient
from manifest . session import Session
URL = " http://localhost:6000 "
try :
_ = requests . post ( URL + " /params " ) . json ( )
MODEL_ALIVE = True
except Exception :
MODEL_ALIVE = False
@pytest.mark.usefixtures ( " sqlite_cache " )
@pytest.mark.usefixtures ( " session_cache " )
def test_init ( sqlite_cache : str , session_cache : str ) - > None :
def test_init ( sqlite_cache : str ) - > None :
""" Test manifest initialization. """
with pytest . raises ( ValueError ) as exc_info :
Manifest (
@ -32,7 +38,6 @@ def test_init(sqlite_cache: str, session_cache: str) -> None:
assert manifest . client_name == " dummy "
assert isinstance ( manifest . client , DummyClient )
assert isinstance ( manifest . cache , SQLiteCache )
assert manifest . session is None
assert manifest . client . n == 1 # type: ignore
assert manifest . stop_token == " "
@ -41,19 +46,16 @@ def test_init(sqlite_cache: str, session_cache: str) -> None:
cache_name = " noop " ,
n = 3 ,
stop_token = " \n " ,
session_id = " _default " ,
)
assert manifest . client_name == " dummy "
assert isinstance ( manifest . client , DummyClient )
assert isinstance ( manifest . cache , NoopCache )
assert isinstance ( manifest . session , Session )
assert manifest . client . n == 3 # type: ignore
assert manifest . stop_token == " \n "
@pytest.mark.usefixtures ( " sqlite_cache " )
@pytest.mark.usefixtures ( " session_cache " )
def test_change_manifest ( sqlite_cache : str , session_cache : str ) - > None :
def test_change_manifest ( sqlite_cache : str ) - > None :
""" Test manifest change. """
manifest = Manifest (
client_name = " dummy " ,
@ -65,7 +67,6 @@ def test_change_manifest(sqlite_cache: str, session_cache: str) -> None:
assert manifest . client_name == " dummy "
assert isinstance ( manifest . client , DummyClient )
assert isinstance ( manifest . cache , SQLiteCache )
assert manifest . session is None
assert manifest . client . n == 1 # type: ignore
assert manifest . stop_token == " "
@ -73,18 +74,14 @@ def test_change_manifest(sqlite_cache: str, session_cache: str) -> None:
assert manifest . client_name == " dummy "
assert isinstance ( manifest . client , DummyClient )
assert isinstance ( manifest . cache , SQLiteCache )
assert manifest . session is None
assert manifest . client . n == 1 # type: ignore
assert manifest . stop_token == " \n "
@pytest.mark.usefixtures ( " sqlite_cache " )
@pytest.mark.usefixtures ( " session_cache " )
@pytest.mark.parametrize ( " n " , [ 1 , 2 ] )
@pytest.mark.parametrize ( " return_response " , [ True , False ] )
def test_run (
sqlite_cache : str , session_cache : str , n : int , return_response : bool
) - > None :
def test_run ( sqlite_cache : str , n : int , return_response : bool ) - > None :
""" Test manifest run. """
manifest = Manifest (
client_name = " dummy " ,
@ -111,15 +108,12 @@ def test_run(
else :
res = cast ( str , result )
assert (
manifest . cache . get_key (
json . dumps (
{
" prompt " : " This is a prompt " ,
" engine " : " dummy " ,
" num_results " : n ,
} ,
sort_keys = True ,
)
manifest . cache . get (
{
" prompt " : " This is a prompt " ,
" engine " : " dummy " ,
" num_results " : n ,
} ,
)
is not None
)
@ -136,16 +130,13 @@ def test_run(
else :
res = cast ( str , result )
assert (
manifest . cache . get_key (
json . dumps (
{
" prompt " : " This is a prompt " ,
" engine " : " dummy " ,
" num_results " : n ,
" run_id " : " 34 " ,
} ,
sort_keys = True ,
)
manifest . cache . get (
{
" prompt " : " This is a prompt " ,
" engine " : " dummy " ,
" num_results " : n ,
" run_id " : " 34 " ,
}
)
is not None
)
@ -162,15 +153,12 @@ def test_run(
else :
res = cast ( str , result )
assert (
manifest . cache . get_key (
json . dumps (
{
" prompt " : " Hello is a prompt " ,
" engine " : " dummy " ,
" num_results " : n ,
} ,
sort_keys = True ,
)
manifest . cache . get (
{
" prompt " : " Hello is a prompt " ,
" engine " : " dummy " ,
" num_results " : n ,
} ,
)
is not None
)
@ -187,15 +175,12 @@ def test_run(
else :
res = cast ( str , result )
assert (
manifest . cache . get_key (
json . dumps (
{
" prompt " : " Hello is a prompt " ,
" engine " : " dummy " ,
" num_results " : n ,
} ,
sort_keys = True ,
)
manifest . cache . get (
{
" prompt " : " Hello is a prompt " ,
" engine " : " dummy " ,
" num_results " : n ,
} ,
)
is not None
)
@ -206,12 +191,9 @@ def test_run(
@pytest.mark.usefixtures ( " sqlite_cache " )
@pytest.mark.usefixtures ( " session_cache " )
@pytest.mark.parametrize ( " n " , [ 1 , 2 ] )
@pytest.mark.parametrize ( " return_response " , [ True , False ] )
def test_batch_run (
sqlite_cache : str , session_cache : str , n : int , return_response : bool
) - > None :
def test_batch_run ( sqlite_cache : str , n : int , return_response : bool ) - > None :
""" Test manifest run. """
manifest = Manifest (
client_name = " dummy " ,
@ -233,6 +215,16 @@ def test_batch_run(
else :
res = cast ( str , result )
assert res == [ " hello " ]
assert (
manifest . cache . get (
{
" prompt " : " This is a prompt " ,
" engine " : " dummy " ,
" num_results " : n ,
} ,
)
is not None
)
prompt = [ " Hello is a prompt " , " Hello is a prompt " ]
result = manifest . run ( prompt , return_response = return_response )
@ -243,6 +235,42 @@ def test_batch_run(
else :
res = cast ( str , result )
assert res == [ " hello " , " hello " ]
assert (
manifest . cache . get (
{
" prompt " : " Hello is a prompt " ,
" engine " : " dummy " ,
" num_results " : n ,
} ,
)
is not None
)
result = manifest . run ( prompt , return_response = True )
res = cast ( Response , result ) . get_response ( manifest . stop_token , is_batch = True )
assert cast ( Response , result ) . is_cached ( )
assert (
manifest . cache . get (
{
" prompt " : " New prompt " ,
" engine " : " dummy " ,
" num_results " : n ,
} ,
)
is None
)
prompt = [ " This is a prompt " , " New prompt " ]
result = manifest . run ( prompt , return_response = return_response )
if return_response :
res = cast ( Response , result ) . get_response (
manifest . stop_token , is_batch = True
)
# Cached because one item is in cache
assert cast ( Response , result ) . is_cached ( )
else :
res = cast ( str , result )
assert res == [ " hello " , " hello " ]
prompt = [ " Hello is a prompt " , " Hello is a prompt " ]
result = manifest . run ( prompt , stop_token = " ll " , return_response = return_response )
@ -253,6 +281,72 @@ def test_batch_run(
assert res == [ " he " , " he " ]
@pytest.mark.usefixtures ( " sqlite_cache " )
def test_abatch_run ( sqlite_cache : str ) - > None :
""" Test manifest run. """
manifest = Manifest (
client_name = " dummy " ,
cache_name = " sqlite " ,
cache_connection = sqlite_cache ,
)
prompt = [ " This is a prompt " ]
result = asyncio . run ( manifest . arun_batch ( prompt , return_response = True ) )
res = cast ( Response , result ) . get_response ( manifest . stop_token , is_batch = True )
assert res == [ " hello " ]
assert (
manifest . cache . get (
{
" prompt " : " This is a prompt " ,
" engine " : " dummy " ,
" num_results " : 1 ,
} ,
)
is not None
)
prompt = [ " Hello is a prompt " , " Hello is a prompt " ]
result = asyncio . run ( manifest . arun_batch ( prompt , return_response = True ) )
res = cast ( Response , result ) . get_response ( manifest . stop_token , is_batch = True )
assert res == [ " hello " , " hello " ]
assert (
manifest . cache . get (
{
" prompt " : " Hello is a prompt " ,
" engine " : " dummy " ,
" num_results " : 1 ,
} ,
)
is not None
)
result = asyncio . run ( manifest . arun_batch ( prompt , return_response = True ) )
res = cast ( Response , result ) . get_response ( manifest . stop_token , is_batch = True )
assert cast ( Response , result ) . is_cached ( )
assert (
manifest . cache . get (
{
" prompt " : " New prompt " ,
" engine " : " dummy " ,
" num_results " : 1 ,
} ,
)
is None
)
prompt = [ " This is a prompt " , " New prompt " ]
result = asyncio . run ( manifest . arun_batch ( prompt , return_response = True ) )
res = cast ( Response , result ) . get_response ( manifest . stop_token , is_batch = True )
# Cached because one item is in cache
assert cast ( Response , result ) . is_cached ( )
assert res == [ " hello " , " hello " ]
prompt = [ " Hello is a prompt " , " Hello is a prompt " ]
result = asyncio . run ( manifest . arun_batch ( prompt , return_response = True ) )
res = cast ( Response , result ) . get_response ( stop_token = " ll " , is_batch = True )
assert res == [ " he " , " he " ]
@pytest.mark.usefixtures ( " sqlite_cache " )
def test_score_run ( sqlite_cache : str ) - > None :
""" Test manifest run. """
@ -264,16 +358,14 @@ def test_score_run(sqlite_cache: str) -> None:
prompt = " This is a prompt "
result = manifest . score_prompt ( prompt )
assert (
manifest . cache . get_key (
json . dumps (
{
" prompt " : " This is a prompt " ,
" engine " : " dummy " ,
} ,
sort_keys = True ,
)
manifest . cache . get (
{
" prompt " : " This is a prompt " ,
" engine " : " dummy " ,
" num_results " : 1 ,
" request_type " : " score_prompt " ,
} ,
)
is not None
)
@ -284,20 +376,35 @@ def test_score_run(sqlite_cache: str) -> None:
" item_dtype " : None ,
" response " : { " choices " : [ { " text " : " This is a prompt " , " logprob " : 0.3 } ] } ,
" cached " : False ,
" request_params " : { " prompt " : " This is a prompt " , " engine " : " dummy " } ,
" request_params " : {
" prompt " : " This is a prompt " ,
" engine " : " dummy " ,
" num_results " : 1 ,
" request_type " : " score_prompt " ,
} ,
}
prompt_list = [ " Hello is a prompt " , " Hello is another prompt " ]
result = manifest . score_prompt ( prompt_list )
assert (
manifest . cache . get_key (
json . dumps (
{
" prompt " : [ " Hello is a prompt " , " Hello is another prompt " ] ,
" engine " : " dummy " ,
} ,
sort_keys = True ,
)
manifest . cache . get (
{
" prompt " : " Hello is a prompt " ,
" engine " : " dummy " ,
" num_results " : 1 ,
" request_type " : " score_prompt " ,
} ,
)
is not None
)
assert (
manifest . cache . get (
{
" prompt " : " Hello is another prompt " ,
" engine " : " dummy " ,
" num_results " : 1 ,
" request_type " : " score_prompt " ,
} ,
)
is not None
)
@ -316,76 +423,64 @@ def test_score_run(sqlite_cache: str) -> None:
" request_params " : {
" prompt " : [ " Hello is a prompt " , " Hello is another prompt " ] ,
" engine " : " dummy " ,
" num_results " : 1 ,
" request_type " : " score_prompt " ,
} ,
}
@pytest.mark.usefixtures ( " session_cache " )
def test_log_query ( session_cache : str ) - > None :
""" Test manifest session logging. """
manifest = Manifest ( client_name = " dummy " , cache_name = " noop " , session_id = " _default " )
prompt = " This is a prompt "
_ = manifest . run ( prompt , return_response = False )
query_key = {
" prompt " : " This is a prompt " ,
" engine " : " dummy " ,
" num_results " : 1 ,
}
response_key = {
" cached " : False ,
" request_params " : query_key ,
" response " : { " choices " : [ { " text " : " hello " } ] } ,
" generation_key " : " choices " ,
" item_dtype " : None ,
" item_key " : " text " ,
" logits_key " : " token_logprobs " ,
}
assert manifest . get_last_queries ( 1 ) == [ ( " This is a prompt " , " hello " ) ]
assert manifest . get_last_queries ( 1 , return_raw_values = True ) == [
( query_key , response_key )
]
assert manifest . get_last_queries ( 3 , return_raw_values = True ) == [
( query_key , response_key )
]
prior_cache_item = ( query_key , response_key )
prompt_lst = [ " This is a prompt " , " This is a prompt2 " ]
_ = manifest . run ( prompt_lst , return_response = False )
query_key = {
" prompt " : [ " This is a prompt " , " This is a prompt2 " ] ,
" engine " : " dummy " ,
" num_results " : 1 ,
}
response_key = {
" cached " : False ,
" generation_key " : " choices " ,
" item_dtype " : None ,
" item_key " : " text " ,
" logits_key " : " token_logprobs " ,
" request_params " : query_key ,
" response " : { " choices " : [ { " text " : " hello " } , { " text " : " hello " } ] } ,
}
assert manifest . get_last_queries ( 1 ) == [
( [ " This is a prompt " , " This is a prompt2 " ] , [ " hello " , " hello " ] )
]
assert manifest . get_last_queries ( 1 , return_raw_values = True ) == [
( query_key , response_key )
]
assert manifest . get_last_queries ( 3 , return_raw_values = True ) == [
prior_cache_item ,
( query_key , response_key ) ,
]
# Test no session
manifest = Manifest (
client_name = " dummy " ,
cache_name = " noop " ,
@pytest.mark.skipif ( not MODEL_ALIVE , reason = f " No model at { URL } " )
@pytest.mark.usefixtures ( " sqlite_cache " )
def test_local_huggingface ( sqlite_cache : str ) - > None :
""" Test local huggingface client. """
client = Manifest (
client_name = " huggingface " ,
client_connection = URL ,
cache_name = " sqlite " ,
cache_connection = sqlite_cache ,
)
prompt = " This is a prompt "
_ = manifest . run ( prompt , return_response = False )
with pytest . raises ( ValueError ) as exc_info :
manifest . get_last_queries ( 1 )
assert (
str ( exc_info . value )
== " Session was not initialized. Set `session_id` when loading Manifest. "
res = client . run ( " Why are there apples? " )
assert isinstance ( res , str ) and len ( res ) > 0
response = cast ( Response , client . run ( " Why are there apples? " , return_response = True ) )
assert isinstance ( response . get_response ( ) , str ) and len ( response . get_response ( ) ) > 0
assert response . is_cached ( ) is True
response = cast ( Response , client . run ( " Why are there apples? " , return_response = True ) )
assert response . is_cached ( ) is True
res_list = client . run ( [ " Why are there apples? " , " Why are there bananas? " ] )
assert isinstance ( res_list , list ) and len ( res_list ) == 2
response = cast (
Response , client . run ( " Why are there bananas? " , return_response = True )
)
assert response . is_cached ( ) is True
res_list = asyncio . run (
client . arun_batch ( [ " Why are there pears? " , " Why are there oranges? " ] )
)
assert isinstance ( res_list , list ) and len ( res_list ) == 2
response = cast (
Response , client . run ( " Why are there oranges? " , return_response = True )
)
assert response . is_cached ( ) is True
scores = client . score_prompt ( " Why are there apples? " )
assert isinstance ( scores , dict ) and len ( scores ) > 0
assert scores [ " cached " ] is False
assert len ( scores [ " response " ] [ " choices " ] [ 0 ] [ " token_logprobs " ] ) == len (
scores [ " response " ] [ " choices " ] [ 0 ] [ " tokens " ]
)
scores = client . score_prompt ( [ " Why are there apples? " , " Why are there bananas? " ] )
assert isinstance ( scores , dict ) and len ( scores ) > 0
assert scores [ " cached " ] is True
assert len ( scores [ " response " ] [ " choices " ] [ 0 ] [ " token_logprobs " ] ) == len (
scores [ " response " ] [ " choices " ] [ 0 ] [ " tokens " ]
)
assert len ( scores [ " response " ] [ " choices " ] [ 0 ] [ " token_logprobs " ] ) == len (
scores [ " response " ] [ " choices " ] [ 0 ] [ " tokens " ]
)