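/**
 * Streams text generation from an OpenAI-compatible API endpoint via the
 * Vercel AI SDK, using the endpoint, API key, and model from user settings.
 */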
import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
import { streamText } from "ai";
import type { ChatMessage } from "gpt-tokenizer/GptEncoding";
import {
  getSettings,
  getTextGenerationState,
  updateResponse,
  updateTextGenerationState,
} from "./pubSub";
import {
  canStartResponding,
  getDefaultChatCompletionCreateParamsStreaming,
  getDefaultChatMessages,
  getFormattedSearchResults,
} from "./textGenerationUtilities";

// Tracks the in-flight request so it can be aborted when generation is interrupted.
let currentAbortController: AbortController | null = null;

interface StreamOptions {
  messages: ChatMessage[];
  /** Called with the full accumulated text after each streamed delta, not just the delta itself. */
  onUpdate: (text: string) => void;
}

async function createOpenAiStream({
  messages,
  onUpdate,
}: StreamOptions): Promise<string> {
  const settings = getSettings();
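  // Point an OpenAI-compatible provider at the user-configured endpoint;
  // the base URL doubles as the provider name.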
  const openaiProvider = createOpenAICompatible({
    name: settings.openAiApiBaseUrl,
    baseURL: settings.openAiApiBaseUrl,
    apiKey: settings.openAiApiKey,
  });

  const params = getDefaultChatCompletionCreateParamsStreaming();

  try {
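    // A fresh controller per request, so interrupting only cancels this stream.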
    currentAbortController = new AbortController();

    const stream = streamText({
      model: openaiProvider.chatModel(settings.openAiApiModel),
      messages: messages.map((msg) => ({
        role: msg.role || "user",
        content: msg.content,
      })),
      maxTokens: params.max_tokens,
      temperature: params.temperature,
      topP: params.top_p,
      frequencyPenalty: params.frequency_penalty,
      presencePenalty: params.presence_penalty,
      abortSignal: currentAbortController.signal,
    });

    let text = "";
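    // Accumulate streamed deltas, checking for user interruption between parts.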
    for await (const part of stream.fullStream) {
      if (getTextGenerationState() === "interrupted") {
        currentAbortController.abort();
        throw new Error("Chat generation interrupted");
      }

      // streamText reports stream errors as parts rather than throwing them.
      if (part.type === "error") throw part.error;

      if (part.type === "text-delta") {
        text += part.textDelta;
        onUpdate(text);
      }
    }

    return text;
  } catch (error) {
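    // Normalize user interruptions and fetch aborts into a single error.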
    if (
      getTextGenerationState() === "interrupted" ||
      (error instanceof DOMException && error.name === "AbortError")
    ) {
      throw new Error("Chat generation interrupted");
    }
    throw error;
  } finally {
    currentAbortController = null;
  }
}

/** Generates a streamed answer to the current query from the formatted search results. */
export async function generateTextWithOpenAi() {
  await canStartResponding();
  updateTextGenerationState("preparingToGenerate");

  const messages = getDefaultChatMessages(getFormattedSearchResults(true));
  updateTextGenerationState("generating");

  await createOpenAiStream({
    messages,
    onUpdate: updateResponse,
  });
}

/** Streams a chat continuation, calling onUpdate with the accumulated partial response. */
export async function generateChatWithOpenAi(
  messages: ChatMessage[],
  onUpdate: (partialResponse: string) => void,
) {
  return createOpenAiStream({ messages, onUpdate });
}
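
/*
 * Usage sketch (illustrative only; the message content and the console
 * logging below are hypothetical, not part of the app's wiring):
 *
 *   const reply = await generateChatWithOpenAi(
 *     [{ role: "user", content: "Summarize the search results." }],
 *     (partial) => console.log(partial),
 *   );
 */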