| Field | Value | Date |
|---|---|---|
| author | Martin Fink <martin@finkmartin.com> | 2025-09-11 09:19:48 +0200 |
| committer | GitHub <noreply@github.com> | 2025-09-11 09:19:48 +0200 |
| committer | GitHub <noreply@github.com> | 2025-09-11 09:19:48 +0200 |
| commit | 17af5f6fc0538f615b8612dcd2cb77c2affad63f (patch) | |
| tree | 76e4c260123b68b93da2417482024ba11f9838ee /archive/2025/summer/bsc_gerg/tests/util.py | |
| parent | a910d0a3e57f4de47cf2387ac239ae8d0eaca507 (diff) | |
| parent | 3e5d3ca82193e8e8561beb9ceac9982f376d84e2 (diff) | |
Merge pull request #10 from walamana/main
Add bsc_gerg
Diffstat (limited to 'archive/2025/summer/bsc_gerg/tests/util.py')
| Mode | File | Lines |
|---|---|---|
| -rw-r--r-- | archive/2025/summer/bsc_gerg/tests/util.py | 59 |

1 file changed, 59 insertions, 0 deletions
```diff
diff --git a/archive/2025/summer/bsc_gerg/tests/util.py b/archive/2025/summer/bsc_gerg/tests/util.py
new file mode 100644
index 000000000..b1da26375
--- /dev/null
+++ b/archive/2025/summer/bsc_gerg/tests/util.py
@@ -0,0 +1,59 @@
+import asyncio
+from typing import AsyncIterable, Tuple
+
+import dotenv
+from openai import OpenAI
+
+SEED = 42
+
+def collect_async(iterable: AsyncIterable):
+    """Synchronously collect all items in the AsyncIterable and return them as a list."""
+    async def do():
+        return [event async for event in iterable]
+    return asyncio.run(do())
+
+
+client: OpenAI | None = None
+def get_openai_client() -> OpenAI:
+    global client
+    if client is None:
+        dotenv.load_dotenv()
+        client = OpenAI()
+    return client
+
+
+def create_completion_openai_sync(
+        messages: list[Tuple[str, str]],
+        model: str = "gpt-4o-mini",
+        temperature=0.0,
+        max_completion_tokens=2048,
+        top_p=0.0,
+        frequency_penalty=0,
+        presence_penalty=0,
+        store=False,
+        logprobs=False,
+    ):
+    response = get_openai_client().chat.completions.create(
+        model=model,
+        messages=[
+            {
+                "role": role,
+                "content": prompt
+            } for role, prompt in messages
+        ],
+        response_format={"type": "text"},
+        temperature=temperature,
+        max_completion_tokens=max_completion_tokens,
+        top_p=top_p,
+        frequency_penalty=frequency_penalty,
+        presence_penalty=presence_penalty,
+        store=store,
+        logprobs=logprobs,
+        seed=SEED,
+        top_logprobs=20 if logprobs else None
+    )
+
+    if logprobs:
+        return response.choices[0].message.content, response.choices[0].logprobs
+    else:
+        return response.choices[0].message.content
```