jan/web-app/src/lib/completion.ts

/* eslint-disable @typescript-eslint/no-explicit-any */
import CryptoJS from 'crypto-js'
import {
  ContentType,
  ChatCompletionRole,
  ThreadMessage,
  MessageStatus,
  EngineManager,
  ModelManager,
  chatCompletionRequestMessage,
  chatCompletion,
  chatCompletionChunk,
  Tool,
} from '@janhq/core'
import { getServiceHub } from '@/hooks/useServiceHub'
import {
  ChatCompletionMessageParam,
  ChatCompletionTool,
  CompletionResponse,
  CompletionResponseChunk,
  models,
  StreamCompletionResponse,
  TokenJS,
  ConfigOptions,
} from 'token.js'
import { ulid } from 'ulidx'
import { MCPTool } from '@/types/completion'
import { CompletionMessagesBuilder } from './messages'
import { ChatCompletionMessageToolCall } from 'openai/resources'
import { ExtensionManager } from './extension'
import { useAppState } from '@/hooks/useAppState'

// Extended config options to include a custom fetch function
type ExtendedConfigOptions = ConfigOptions & {
  fetch?: typeof fetch
}

export type ChatCompletionResponse =
  | chatCompletion
  | AsyncIterable<chatCompletionChunk>
  | StreamCompletionResponse
  | CompletionResponse
/**
 * Creates a user thread message from text content and optional file attachments.
 * Image attachments are converted into image content parts; other attachment
 * types are ignored.
 * @param threadId - The id of the thread the message belongs to
 * @param content - The text content of the message
 * @param attachments - Optional file attachments
 * @returns A `ThreadMessage` with the user role
 */
export const newUserThreadContent = (
  threadId: string,
  content: string,
  attachments?: Array<{
    name: string
    type: string
    size: number
    base64: string
    dataUrl: string
  }>
): ThreadMessage => {
  const contentParts = [
    {
      type: ContentType.Text,
      text: {
        value: content,
        annotations: [],
      },
    },
  ]
  // Add attachments to content array
  if (attachments) {
    attachments.forEach((attachment) => {
      if (attachment.type.startsWith('image/')) {
        contentParts.push({
          type: ContentType.Image,
          image_url: {
            url: `data:${attachment.type};base64,${attachment.base64}`,
            detail: 'auto',
          },
        } as any)
      }
    })
  }
  return {
    type: 'text',
    role: ChatCompletionRole.User,
    content: contentParts,
    id: ulid(),
    object: 'thread.message',
    thread_id: threadId,
    status: MessageStatus.Ready,
    created_at: 0,
    completed_at: 0,
  }
}
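
// Usage sketch (values are illustrative, not taken from the app): build a user
// message with an image attachment; only `image/*` attachments become content parts.
//   const userMsg = newUserThreadContent(thread.id, 'What is in this picture?', [
//     { name: 'photo.png', type: 'image/png', size: 2048, base64, dataUrl },
//   ])
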
/**
 * Creates an assistant thread message from text content.
 * @param threadId - The id of the thread the message belongs to
 * @param content - The text content of the message
 * @param metadata - Optional metadata attached to the message
 * @returns A `ThreadMessage` with the assistant role
 */
export const newAssistantThreadContent = (
  threadId: string,
  content: string,
  metadata: Record<string, unknown> = {}
): ThreadMessage => ({
  type: 'text',
  role: ChatCompletionRole.Assistant,
  content: [
    {
      type: ContentType.Text,
      text: {
        value: content,
        annotations: [],
      },
    },
  ],
  id: ulid(),
  object: 'thread.message',
  thread_id: threadId,
  status: MessageStatus.Ready,
  created_at: 0,
  completed_at: 0,
  metadata,
})
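
// Usage sketch (illustrative): create an assistant message from streamed text,
// optionally attaching metadata such as accumulated tool calls.
//   const assistantMsg = newAssistantThreadContent(thread.id, fullText, {
//     tool_calls: [],
//   })
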
/**
 * Empty thread content object.
 */
export const emptyThreadContent: ThreadMessage = {
  type: 'text',
  role: ChatCompletionRole.Assistant,
  id: ulid(),
  object: 'thread.message',
  thread_id: '',
  content: [],
  status: MessageStatus.Ready,
  created_at: 0,
  completed_at: 0,
}
/**
 * Sends a chat completion request to the model provider.
 * @param thread - The active thread (provides the model id and thread id)
 * @param provider - The model provider configuration
 * @param messages - The chat messages to send
 * @param abortController - Controller used to cancel the request
 * @param tools - MCP tools exposed to the model
 * @param stream - Whether to request a streaming response
 * @param params - Additional provider-specific request parameters
 * @returns The completion response, or `undefined` if no model is selected
 */
export const sendCompletion = async (
  thread: Thread,
  provider: ModelProvider,
  messages: ChatCompletionMessageParam[],
  abortController: AbortController,
  tools: MCPTool[] = [],
  stream: boolean = true,
  params: Record<string, object> = {}
): Promise<ChatCompletionResponse | undefined> => {
  if (!thread?.model?.id || !provider) return undefined
  let providerName = provider.provider as unknown as keyof typeof models
  if (!Object.keys(models).some((key) => key === providerName))
    providerName = 'openai-compatible'
  // Decrypt API key if it exists and is encrypted
  const secretKey = await getServiceHub().core().getAppToken()
  const decryptApiKey = (encryptedKey: string, key: string): string => {
    try {
      const bytes = CryptoJS.AES.decrypt(encryptedKey, key)
      const decryptedKey = bytes.toString(CryptoJS.enc.Utf8)
      if (!decryptedKey) {
        throw new Error('Failed to decrypt API key: result is empty')
      }
      return decryptedKey
    } catch (error) {
      console.warn('Failed to decrypt API key:', error)
      throw new Error('Failed to decrypt API key')
    }
  }
  if (!secretKey) {
    throw new Error('Encryption key unavailable: cannot decrypt API key.')
  }
  if (!provider.api_key) {
    throw new Error('API key is missing for the selected provider.')
  }
  const apiKey = decryptApiKey(provider.api_key, secretKey)
  const tokenJS = new TokenJS({
    apiKey,
    // TODO: Retrieve from extension settings
    baseURL: provider.base_url,
    // Use Tauri's fetch (via the service hub) outside dev mode to avoid CORS issues
    fetch: IS_DEV ? fetch : getServiceHub().providers().fetch(),
    // OpenRouter identification headers for Jan
    // ref: https://openrouter.ai/docs/api-reference/overview#headers
    ...(provider.provider === 'openrouter' && {
      defaultHeaders: {
        'HTTP-Referer': 'https://jan.ai',
        'X-Title': 'Jan',
      },
    }),
    // Add an Origin header for local providers to avoid CORS issues
    ...((provider.base_url?.includes('localhost:') ||
      provider.base_url?.includes('127.0.0.1:')) && {
      fetch: getServiceHub().providers().fetch(),
      defaultHeaders: {
        Origin: 'tauri://localhost',
      },
    }),
  } as ExtendedConfigOptions)
  if (
    thread.model.id &&
    models[providerName]?.models !== true && // Skip if provider accepts any model (models: true)
    !Object.values(models[providerName]).flat().includes(thread.model.id) &&
    !tokenJS.extendedModelExist(providerName as any, thread.model.id) &&
    provider.provider !== 'llamacpp'
  ) {
    try {
      tokenJS.extendModelList(
        providerName as any,
        thread.model.id,
        // Inherit model capabilities from a built-in model;
        // this can be any model that supports all capabilities
        models.anthropic.models[0]
      )
    } catch (error) {
      console.error(
        `Failed to extend model list for ${providerName} with model ${thread.model.id}:`,
        error
      )
    }
  }
  const engine = ExtensionManager.getInstance().getEngine(provider.provider)
  const completion = engine
    ? await engine.chat(
        {
          messages: messages as chatCompletionRequestMessage[],
          model: thread.model?.id,
          thread_id: thread.id,
          tools: normalizeTools(tools),
          tool_choice: tools.length ? 'auto' : undefined,
          stream: true,
          ...params,
        },
        abortController
      )
    : stream
      ? await tokenJS.chat.completions.create(
          {
            stream: true,
            provider: providerName as any,
            model: thread.model?.id,
            messages,
            tools: normalizeTools(tools),
            tool_choice: tools.length ? 'auto' : undefined,
            ...params,
          },
          {
            signal: abortController.signal,
          }
        )
      : await tokenJS.chat.completions.create({
          stream: false,
          provider: providerName,
          model: thread.model?.id,
          messages,
          tools: normalizeTools(tools),
          tool_choice: tools.length ? 'auto' : undefined,
          ...params,
        })
  return completion
}
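
// Usage sketch (assumes `thread`, `provider`, `messages`, and `mcpTools` come from
// the caller's chat flow; names are illustrative): request a streamed completion.
//   const controller = new AbortController()
//   const completion = await sendCompletion(thread, provider, messages, controller, mcpTools)
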
export const isCompletionResponse = (
  response: ChatCompletionResponse
): response is CompletionResponse | chatCompletion => {
  return 'choices' in response
}
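
// Usage sketch: distinguish a one-shot response from a stream of chunks.
//   if (completion && isCompletionResponse(completion)) {
//     console.log(completion.choices[0]?.message?.content)
//   } else if (completion) {
//     // consume the response as an async iterable of chunks instead
//   }
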
/**
 * Stops a running model by unloading it from its provider.
 * @param provider - The provider name
 * @param model - The model id to unload
 */
export const stopModel = async (
  provider: string,
  model: string
): Promise<void> => {
  const providerObj = EngineManager.instance().get(provider)
  const modelObj = ModelManager.instance().get(model)
  if (providerObj && modelObj) return providerObj?.unload(model).then(() => {})
}
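
// Usage sketch (the model id is illustrative):
//   await stopModel('llamacpp', 'qwen3-4b-instruct')
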
/**
 * Normalizes tools for the chat completion request by converting
 * MCPTool objects into ChatCompletionTool objects.
 * @param tools - The MCP tools to convert
 * @returns The converted tools, or `undefined` if none are provided
 */
export const normalizeTools = (
  tools: MCPTool[]
): ChatCompletionTool[] | Tool[] | undefined => {
  if (tools.length === 0) return undefined
  return tools.map((tool) => ({
    type: 'function',
    function: {
      name: tool.name,
      description: tool.description?.slice(0, 1024),
      parameters: tool.inputSchema,
      strict: false,
    },
  }))
}
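
// Example of the resulting shape (tool values are illustrative):
//   normalizeTools([{ name: 'read_file', description: 'Read a file', inputSchema: schema }])
//   // => [{ type: 'function',
//   //       function: { name: 'read_file', description: 'Read a file', parameters: schema, strict: false } }]
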
/**
 * Accumulates tool calls from a streamed completion chunk.
 * @param part - The current stream chunk
 * @param currentCall - The tool call currently being assembled, if any
 * @param calls - The accumulated tool calls, updated in place
 * @returns The accumulated tool calls
 */
export const extractToolCall = (
  part: chatCompletionChunk | CompletionResponseChunk,
  currentCall: ChatCompletionMessageToolCall | null,
  calls: ChatCompletionMessageToolCall[]
) => {
  const deltaToolCalls = part.choices[0].delta.tool_calls
  // Handle the beginning of a new tool call
  if (deltaToolCalls?.[0]?.index !== undefined && deltaToolCalls[0]?.function) {
    const index = deltaToolCalls[0].index
    // Create new tool call if this is the first chunk for it
    if (!calls[index]) {
      calls[index] = {
        id: deltaToolCalls[0]?.id || ulid(),
        function: {
          name: deltaToolCalls[0]?.function?.name || '',
          arguments: deltaToolCalls[0]?.function?.arguments || '',
        },
        type: 'function',
      }
      currentCall = calls[index]
    } else {
      // Continuation of existing tool call
      currentCall = calls[index]
      // Append to function name or arguments if they exist in this chunk
      if (
        deltaToolCalls[0]?.function?.name &&
        currentCall!.function.name !== deltaToolCalls[0]?.function?.name
      ) {
        currentCall!.function.name += deltaToolCalls[0].function.name
      }
      if (deltaToolCalls[0]?.function?.arguments) {
        currentCall!.function.arguments += deltaToolCalls[0].function.arguments
      }
    }
  }
  return calls
}
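
// Usage sketch: accumulate tool calls while consuming a streamed completion.
//   const calls: ChatCompletionMessageToolCall[] = []
//   for await (const chunk of completion as AsyncIterable<chatCompletionChunk>) {
//     if (chunk.choices[0]?.delta?.tool_calls) extractToolCall(chunk, null, calls)
//   }
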
/**
 * Processes completed tool calls after a completion response:
 * requests user approval if needed, invokes the MCP tools, and
 * records the results in the message metadata and message builder.
 * @param calls - The accumulated tool calls to execute
 * @param builder - The message builder used to append tool results
 * @param message - The assistant message being updated
 * @param abortController - Controller used to stop tool execution
 * @param approvedTools - Tool names already approved per thread
 * @param showModal - Callback that asks the user to approve a tool call
 * @param allowAllMCPPermissions - Whether all tool calls are auto-approved
 * @returns The updated message, or `undefined` if there were no tool calls
 */
export const postMessageProcessing = async (
  calls: ChatCompletionMessageToolCall[],
  builder: CompletionMessagesBuilder,
  message: ThreadMessage,
  abortController: AbortController,
  approvedTools: Record<string, string[]> = {},
  showModal?: (
    toolName: string,
    threadId: string,
    toolParameters?: object
  ) => Promise<boolean>,
  allowAllMCPPermissions: boolean = false
) => {
  // Handle completed tool calls
  if (calls.length) {
    for (const toolCall of calls) {
      if (abortController.signal.aborted) break
      const toolId = ulid()
      const toolCallsMetadata =
        message.metadata?.tool_calls &&
        Array.isArray(message.metadata?.tool_calls)
          ? message.metadata?.tool_calls
          : []
      message.metadata = {
        ...(message.metadata ?? {}),
        tool_calls: [
          ...toolCallsMetadata,
          {
            tool: {
              ...(toolCall as object),
              id: toolId,
            },
            response: undefined,
            state: 'pending',
          },
        ],
      }
      // Check if tool is approved or show modal for approval
      let toolParameters = {}
      if (toolCall.function.arguments.length) {
        try {
          console.log('Raw tool arguments:', toolCall.function.arguments)
          toolParameters = JSON.parse(toolCall.function.arguments)
          console.log('Parsed tool parameters:', toolParameters)
        } catch (error) {
          console.error('Failed to parse tool arguments:', error)
          console.error(
            'Raw arguments that failed:',
            toolCall.function.arguments
          )
        }
      }
      const approved =
        allowAllMCPPermissions ||
        approvedTools[message.thread_id]?.includes(toolCall.function.name) ||
        (showModal
          ? await showModal(
              toolCall.function.name,
              message.thread_id,
              toolParameters
            )
          : true)
      const { promise, cancel } = getServiceHub()
        .mcp()
        .callToolWithCancellation({
          toolName: toolCall.function.name,
          arguments: toolCall.function.arguments.length ? toolParameters : {},
        })
      useAppState.getState().setCancelToolCall(cancel)
      let result = approved
        ? await promise.catch((e) => {
            console.error('Tool call failed:', e)
            return {
              content: [
                {
                  type: 'text',
                  text: `Error calling tool ${toolCall.function.name}: ${e.message ?? e}`,
                },
              ],
              error: true,
            }
          })
        : {
            content: [
              {
                type: 'text',
                text: 'The user has chosen to disallow the tool call.',
              },
            ],
          }
      if (typeof result === 'string') {
        result = {
          content: [
            {
              type: 'text',
              text: result,
            },
          ],
        }
      }
      // update message metadata
      message.metadata = {
        ...(message.metadata ?? {}),
        tool_calls: [
          ...toolCallsMetadata,
          {
            tool: {
              ...toolCall,
              id: toolId,
            },
            response: result,
            state: 'ready',
          },
        ],
      }
      builder.addToolMessage(result.content[0]?.text ?? '', toolCall.id)
    }
    return message
  }
}
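
// Usage sketch (assumes the caller's streaming loop produced `calls`, a
// CompletionMessagesBuilder, and the current assistant message; the approval
// callback is optional):
//   const updated = await postMessageProcessing(
//     calls,
//     builder,
//     assistantMessage,
//     abortController,
//     approvedTools,
//     showApprovalModal,
//     allowAllMCPPermissions
//   )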