author    Jonas Gerg <joniogerg@gmail.com>  2025-09-09 20:06:52 +0200
committer Jonas Gerg <joniogerg@gmail.com>  2025-09-09 20:06:52 +0200
commit    3e5d3ca82193e8e8561beb9ceac9982f376d84e2 (patch)
tree      76e4c260123b68b93da2417482024ba11f9838ee /archive/2025/summer/bsc_gerg/tests/util.py
parent    a910d0a3e57f4de47cf2387ac239ae8d0eaca507 (diff)
Add bsc_gerg
Diffstat (limited to 'archive/2025/summer/bsc_gerg/tests/util.py')
-rw-r--r--  archive/2025/summer/bsc_gerg/tests/util.py | 59
1 file changed, 59 insertions(+), 0 deletions(-)
diff --git a/archive/2025/summer/bsc_gerg/tests/util.py b/archive/2025/summer/bsc_gerg/tests/util.py
new file mode 100644
index 000000000..b1da26375
--- /dev/null
+++ b/archive/2025/summer/bsc_gerg/tests/util.py
@@ -0,0 +1,59 @@
+import asyncio
+from collections.abc import AsyncIterable
+
+import dotenv
+from openai import OpenAI
+
+SEED = 42
+
+def collect_async(iterable: AsyncIterable):
+    """Synchronously collect all items in the AsyncIterable and return them as a list."""
+    async def do():
+        return [event async for event in iterable]
+    return asyncio.run(do())
+
+
+client: OpenAI | None = None
+def get_openai_client() -> OpenAI:
+    global client
+    if client is None:
+        dotenv.load_dotenv()
+        client = OpenAI()
+    return client
+
+
+def create_completion_openai_sync(
+    messages: list[tuple[str, str]],
+    model: str = "gpt-4o-mini",
+    temperature: float = 0.0,
+    max_completion_tokens: int = 2048,
+    top_p: float = 0.0,
+    frequency_penalty: float = 0,
+    presence_penalty: float = 0,
+    store: bool = False,
+    logprobs: bool = False,
+):
+    response = get_openai_client().chat.completions.create(
+        model=model,
+        messages=[
+            {
+                "role": role,
+                "content": prompt
+            } for role, prompt in messages
+        ],
+        response_format={"type": "text"},
+        temperature=temperature,
+        max_completion_tokens=max_completion_tokens,
+        top_p=top_p,
+        frequency_penalty=frequency_penalty,
+        presence_penalty=presence_penalty,
+        store=store,
+        logprobs=logprobs,
+        seed=SEED,
+        top_logprobs=20 if logprobs else None
+    )
+
+    if logprobs:
+        return response.choices[0].message.content, response.choices[0].logprobs
+    else:
+        return response.choices[0].message.content
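
Usage sketch (editor's addition, not part of the commit): assuming OPENAI_API_KEY is available via the environment or a .env file picked up by dotenv, the helpers above could be exercised as below. The async generator numbers is a made-up stand-in for a real event stream, and the import path depends on how the tests package is laid out.

from util import collect_async, create_completion_openai_sync

# collect_async drains any async iterable from synchronous test code.
async def numbers():
    for i in range(3):
        yield i

assert collect_async(numbers()) == [0, 1, 2]

# Without logprobs, the helper returns just the completion text.
content = create_completion_openai_sync(
    messages=[
        ("system", "You are a helpful assistant."),
        ("user", "Reply with a single word."),
    ],
)
print(content)

# With logprobs=True, it returns a (content, logprobs) tuple instead.
content, logprobs = create_completion_openai_sync(
    messages=[("user", "Reply with a single word.")],
    logprobs=True,
)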