diff --git a/programming_runs/generators/generator_utils.py b/programming_runs/generators/generator_utils.py index adc6019..ea81fe6 100644 --- a/programming_runs/generators/generator_utils.py +++ b/programming_runs/generators/generator_utils.py @@ -1,4 +1,4 @@ -from generators.model import ModelBase, Message, message_to_str, messages_to_str +from generators.model import ModelBase, Message import random from typing import Union, List, Optional, Callable @@ -41,7 +41,7 @@ def generic_generate_func_impl( content=prompt, ), Message( - role="user", # TODO: check this + role="user", # TODO: check this content=reflexion_few_shot, ), Message( @@ -61,8 +61,7 @@ def generic_generate_func_impl( content=f"[improved impl]:\n{func_sig}", ), ] - func_bodies = model.generate_chat( - messages=messages, num_comps=num_comps, temperature=temperature) + func_bodies = model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature) else: system_prompt = f"{simple_chat_instruction}\n{code_block_instruction}" print_messages(system_prompt, func_sig) @@ -76,8 +75,7 @@ def generic_generate_func_impl( content=func_sig, ), ] - func_bodies = model.generate_chat( - messages=messages, num_comps=num_comps, temperature=temperature) + func_bodies = model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature) else: if strategy == "reflexion": prompt = f"{reflexion_completion_instruction}\n{add_code_block(prev_func_impl)}\n\nunit tests:\n{feedback}\n\nhint:\n{self_reflection}\n\n# improved implementation\n{func_sig}\n{code_block_instruction}" @@ -95,8 +93,7 @@ def generic_generate_func_impl( return func_body_str else: - func_bodies = [parse_code_block(func_body) - for func_body in func_bodies] + func_bodies = [parse_code_block(func_body) for func_body in func_bodies] print_generated_func_body("\n\n".join(func_bodies)) return func_bodies @@ -105,7 +102,7 @@ def generic_generate_internal_tests( func_sig: str, model: ModelBase, max_num_tests: int, - test_generation_few_shot: List[Message], + test_generation_few_shot: str, test_generation_chat_instruction: str, test_generation_completion_instruction: str, parse_tests: Callable[[str], List[str]], @@ -122,7 +119,7 @@ def generic_generate_internal_tests( ), Message( role="user", - content=f"{messages_to_str(test_generation_few_shot)}\n\n[func signature]:\n{func_sig}\n\n[think]:" + content=f"{test_generation_few_shot}\n\n[func signature]:\n{func_sig}\n\n[think]:" ) ] output = model.generate_chat(messages=messages, max_tokens=1024) @@ -133,10 +130,9 @@ def generic_generate_internal_tests( role="system", content=test_generation_chat_instruction, ), - ] + test_generation_few_shot + [ Message( role="user", - content=f"{func_sig}" + content=f"{test_generation_few_shot}\n\n[func signature]:\n{func_sig}\n\n[unit tests]:", ) ] output = model.generate_chat(messages=messages, max_tokens=1024) @@ -197,7 +193,6 @@ def sample_n_random(items: List[str], n: int) -> List[str]: return items return random.sample(items, n) - def print_messages(system_message_text: str, user_message_text: str) -> None: print(f"""----------------------- SYSTEM MESSAGE -----------------------) {system_message_text} @@ -207,9 +202,7 @@ def print_messages(system_message_text: str, user_message_text: str) -> None: ---------------------------------------------- """, flush=True) - def print_generated_func_body(func_body_str: str) -> None: print(f"""--------------------- GENERATED FUNC BODY --------------------- {func_body_str} ------------------------------------------""") - diff --git a/programming_runs/generators/py_generate.py b/programming_runs/generators/py_generate.py index 82a633e..34d4f7c 100644 --- a/programming_runs/generators/py_generate.py +++ b/programming_runs/generators/py_generate.py @@ -1,4 +1,4 @@ -from generators.model import Message, ModelBase, messages_to_str +from generators.model import ModelBase, message_to_str from .generator_types import Generator from .generator_utils import generic_generate_func_impl, generic_generate_internal_tests, generic_generate_self_reflection @@ -221,22 +221,24 @@ The implementation failed 4 out of the 7 test cases due to an IndexError. The is END OF EXAMPLES """ -PY_TEST_GENERATION_FEW_SHOT = [ - Message(role="user", content="""def add3Numbers(x, y, z): +PY_TEST_GENERATION_FEW_SHOT = """Examples: +func signature: +def add3Numbers(x, y, z): \"\"\" Add three numbers together. This function takes three numbers as input and returns the sum of the three numbers. - \"\"\""""), - Message(role="assistant", content="""assert add3Numbers(1, 2, 3) == 6 + \"\"\" +unit tests: +assert add3Numbers(1, 2, 3) == 6 assert add3Numbers(-1, 2, 3) == 4 assert add3Numbers(1, -2, 3) == 2 assert add3Numbers(1, 2, -3) == 0 assert add3Numbers(-3, -2, -1) == -6 -assert add3Numbers(0, 0, 0) == 0""") -] +assert add3Numbers(0, 0, 0) == 0 +""" PY_TEST_GENERATION_COMPLETION_INSTRUCTION = f"""You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring. -{messages_to_str(PY_TEST_GENERATION_FEW_SHOT)}""" +{PY_TEST_GENERATION_FEW_SHOT}""" PY_TEST_GENERATION_CHAT_INSTRUCTION = """You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring.""" diff --git a/programming_runs/generators/rs_generate.py b/programming_runs/generators/rs_generate.py index 7db258e..6bf15d8 100644 --- a/programming_runs/generators/rs_generate.py +++ b/programming_runs/generators/rs_generate.py @@ -1,4 +1,4 @@ -from generators.model import Message, ModelBase, messages_to_str +from generators.model import ModelBase from .generator_types import Generator from .generator_utils import generic_generate_func_impl, generic_generate_internal_tests, generic_generate_self_reflection from .parse import parse_code_block, add_code_block @@ -47,18 +47,21 @@ fn add(a: i32, b: i32) -> i32 { END EXAMPLES ''' -RS_TEST_GENERATION_FEW_SHOT = [ - Message(role="user", content="""/// Add three numbers together. +RS_TEST_GENERATION_FEW_SHOT = """For example: + +func signature: +/// Add three numbers together. /// This function takes three numbers as input and returns the sum of the three numbers. fn add3Numbers(x: i32, y: i32, z: i32) -> i32 { -"""), - Message(role="assistant", content="""assert_eq!(add3Numbers(1, 2, 3), 6); + +unit tests: +assert_eq!(add3Numbers(1, 2, 3), 6); assert_eq!(add3Numbers(-1, 2, 3), 4); assert_eq!(add3Numbers(1, -2, 3), 2); assert_eq!(add3Numbers(1, 2, -3), 0); assert_eq!(add3Numbers(-3, -2, -1), -6); -assert_eq!(add3Numbers(0, 0, 0), 0);""") -] +assert_eq!(add3Numbers(0, 0, 0), 0); +""" RS_SELF_REFLECTION_FEW_SHOT = '''Example 1: [function impl]: