♻️ refactor: reformat to follow Prettier conventions

LazyYuuki 2025-06-15 18:54:22 +08:00
parent 665de7df55
commit 4b3a0918fe
2 changed files with 178 additions and 178 deletions
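
Note: the repository's Prettier configuration is not included in this commit, but the changes below are consistent with a config along these lines (a hypothetical sketch — the option names are real Prettier options, the values are inferred from the diff):

// prettier.config.mjs — inferred example, not the committed file
export default {
  semi: false, // drop statement-terminating semicolons
  singleQuote: true, // "double" quotes become 'single' quotes
  trailingComma: 'es5', // keep trailing commas in object/array literals, drop them after final function arguments
}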

useAppState.ts

@@ -1,36 +1,36 @@
-import { create } from "zustand";
-import { ThreadMessage } from "@janhq/core";
-import { MCPTool } from "@/types/completion";
-import { useAssistant } from "./useAssistant";
-import { ChatCompletionMessageToolCall } from "openai/resources";
+import { create } from 'zustand'
+import { ThreadMessage } from '@janhq/core'
+import { MCPTool } from '@/types/completion'
+import { useAssistant } from './useAssistant'
+import { ChatCompletionMessageToolCall } from 'openai/resources'
 
 type AppState = {
-  streamingContent?: ThreadMessage;
-  loadingModel?: boolean;
-  tools: MCPTool[];
-  serverStatus: "running" | "stopped" | "pending";
-  abortControllers: Record<string, AbortController>;
-  tokenSpeed?: TokenSpeed;
-  currentToolCall?: ChatCompletionMessageToolCall;
-  showOutOfContextDialog?: boolean;
-  setServerStatus: (value: "running" | "stopped" | "pending") => void;
-  updateStreamingContent: (content: ThreadMessage | undefined) => void;
+  streamingContent?: ThreadMessage
+  loadingModel?: boolean
+  tools: MCPTool[]
+  serverStatus: 'running' | 'stopped' | 'pending'
+  abortControllers: Record<string, AbortController>
+  tokenSpeed?: TokenSpeed
+  currentToolCall?: ChatCompletionMessageToolCall
+  showOutOfContextDialog?: boolean
+  setServerStatus: (value: 'running' | 'stopped' | 'pending') => void
+  updateStreamingContent: (content: ThreadMessage | undefined) => void
   updateCurrentToolCall: (
-    toolCall: ChatCompletionMessageToolCall | undefined,
-  ) => void;
-  updateLoadingModel: (loading: boolean) => void;
-  updateTools: (tools: MCPTool[]) => void;
-  setAbortController: (threadId: string, controller: AbortController) => void;
-  updateTokenSpeed: (message: ThreadMessage) => void;
-  resetTokenSpeed: () => void;
-  setOutOfContextDialog: (show: boolean) => void;
-};
+    toolCall: ChatCompletionMessageToolCall | undefined
+  ) => void
+  updateLoadingModel: (loading: boolean) => void
+  updateTools: (tools: MCPTool[]) => void
+  setAbortController: (threadId: string, controller: AbortController) => void
+  updateTokenSpeed: (message: ThreadMessage) => void
+  resetTokenSpeed: () => void
+  setOutOfContextDialog: (show: boolean) => void
+}
 
 export const useAppState = create<AppState>()((set) => ({
   streamingContent: undefined,
   loadingModel: false,
   tools: [],
-  serverStatus: "stopped",
+  serverStatus: 'stopped',
   abortControllers: {},
   tokenSpeed: undefined,
   currentToolCall: undefined,
@@ -46,19 +46,19 @@ export const useAppState = create<AppState>()((set) => ({
             },
           }
         : undefined,
-    }));
-    console.log(useAppState.getState().streamingContent);
+    }))
+    console.log(useAppState.getState().streamingContent)
   },
   updateCurrentToolCall: (toolCall) => {
     set(() => ({
       currentToolCall: toolCall,
-    }));
+    }))
   },
   updateLoadingModel: (loading) => {
-    set({ loadingModel: loading });
+    set({ loadingModel: loading })
   },
   updateTools: (tools) => {
-    set({ tools });
+    set({ tools })
   },
   setServerStatus: (value) => set({ serverStatus: value }),
   setAbortController: (threadId, controller) => {
@@ -67,11 +67,11 @@ export const useAppState = create<AppState>()((set) => ({
         ...state.abortControllers,
         [threadId]: controller,
       },
-    }));
+    }))
   },
   updateTokenSpeed: (message) =>
     set((state) => {
-      const currentTimestamp = new Date().getTime(); // Get current time in milliseconds
+      const currentTimestamp = new Date().getTime() // Get current time in milliseconds
       if (!state.tokenSpeed) {
         // If this is the first update, just set the lastTimestamp and return
         return {
@@ -81,14 +81,14 @@ export const useAppState = create<AppState>()((set) => ({
             tokenCount: 1,
             message: message.id,
           },
-        };
+        }
       }
       const timeDiffInSeconds =
-        (currentTimestamp - state.tokenSpeed.lastTimestamp) / 1000; // Time difference in seconds
-      const totalTokenCount = state.tokenSpeed.tokenCount + 1;
+        (currentTimestamp - state.tokenSpeed.lastTimestamp) / 1000 // Time difference in seconds
+      const totalTokenCount = state.tokenSpeed.tokenCount + 1
       const averageTokenSpeed =
-        totalTokenCount / (timeDiffInSeconds > 0 ? timeDiffInSeconds : 1); // Calculate average token speed
+        totalTokenCount / (timeDiffInSeconds > 0 ? timeDiffInSeconds : 1) // Calculate average token speed
       return {
         tokenSpeed: {
           ...state.tokenSpeed,
@@ -96,7 +96,7 @@ export const useAppState = create<AppState>()((set) => ({
           tokenCount: totalTokenCount,
           message: message.id,
         },
-      };
+      }
     }),
   resetTokenSpeed: () =>
     set({
@@ -105,6 +105,6 @@ export const useAppState = create<AppState>()((set) => ({
   setOutOfContextDialog: (show) => {
     set(() => ({
       showOutOfContextDialog: show,
-    }));
+    }))
   },
-}));
+}))
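
Aside on the hunks above: updateTokenSpeed computes a running average rather than an instantaneous rate. On this reading of the visible lines, lastTimestamp is set when the first token arrives and carried forward by the ...state.tokenSpeed spread, so averageTokenSpeed = totalTokenCount / secondsSinceFirstToken, with a divisor of 1 substituted whenever the elapsed time is not positive. For example, if the 100th token arrives 4 s after the first, the reported speed is 100 / 4 = 25 tokens/s.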

useChat.ts

@@ -1,12 +1,12 @@
-import { useCallback, useEffect, useMemo } from "react";
-import { usePrompt } from "./usePrompt";
-import { useModelProvider } from "./useModelProvider";
-import { useThreads } from "./useThreads";
-import { useAppState } from "./useAppState";
-import { useMessages } from "./useMessages";
-import { useRouter } from "@tanstack/react-router";
-import { defaultModel } from "@/lib/models";
-import { route } from "@/constants/routes";
+import { useCallback, useEffect, useMemo } from 'react'
+import { usePrompt } from './usePrompt'
+import { useModelProvider } from './useModelProvider'
+import { useThreads } from './useThreads'
+import { useAppState } from './useAppState'
+import { useMessages } from './useMessages'
+import { useRouter } from '@tanstack/react-router'
+import { defaultModel } from '@/lib/models'
+import { route } from '@/constants/routes'
 import {
   emptyThreadContent,
   extractToolCall,
@@ -15,23 +15,23 @@ import {
   newUserThreadContent,
   postMessageProcessing,
   sendCompletion,
-} from "@/lib/completion";
-import { CompletionMessagesBuilder } from "@/lib/messages";
-import { ChatCompletionMessageToolCall } from "openai/resources";
-import { useAssistant } from "./useAssistant";
-import { toast } from "sonner";
-import { getTools } from "@/services/mcp";
-import { MCPTool } from "@/types/completion";
-import { listen } from "@tauri-apps/api/event";
-import { SystemEvent } from "@/types/events";
-import { stopModel, startModel, stopAllModels } from "@/services/models";
+} from '@/lib/completion'
+import { CompletionMessagesBuilder } from '@/lib/messages'
+import { ChatCompletionMessageToolCall } from 'openai/resources'
+import { useAssistant } from './useAssistant'
+import { toast } from 'sonner'
+import { getTools } from '@/services/mcp'
+import { MCPTool } from '@/types/completion'
+import { listen } from '@tauri-apps/api/event'
+import { SystemEvent } from '@/types/events'
+import { stopModel, startModel, stopAllModels } from '@/services/models'
 
-import { useToolApproval } from "@/hooks/useToolApproval";
-import { useToolAvailable } from "@/hooks/useToolAvailable";
-import { OUT_OF_CONTEXT_SIZE } from "@/utils/error";
+import { useToolApproval } from '@/hooks/useToolApproval'
+import { useToolAvailable } from '@/hooks/useToolAvailable'
+import { OUT_OF_CONTEXT_SIZE } from '@/utils/error'
 
 export const useChat = () => {
-  const { prompt, setPrompt } = usePrompt();
+  const { prompt, setPrompt } = usePrompt()
   const {
     tools,
     updateTokenSpeed,
@@ -40,51 +40,51 @@ export const useChat = () => {
     updateStreamingContent,
     updateLoadingModel,
     setAbortController,
-  } = useAppState();
-  const { currentAssistant } = useAssistant();
-  const { updateProvider } = useModelProvider();
+  } = useAppState()
+  const { currentAssistant } = useAssistant()
+  const { updateProvider } = useModelProvider()
   const { approvedTools, showApprovalModal, allowAllMCPPermissions } =
-    useToolApproval();
-  const { getDisabledToolsForThread } = useToolAvailable();
+    useToolApproval()
+  const { getDisabledToolsForThread } = useToolAvailable()
   const { getProviderByName, selectedModel, selectedProvider } =
-    useModelProvider();
+    useModelProvider()
   const {
     getCurrentThread: retrieveThread,
     createThread,
     updateThreadTimestamp,
-  } = useThreads();
-  const { getMessages, addMessage } = useMessages();
-  const router = useRouter();
+  } = useThreads()
+  const { getMessages, addMessage } = useMessages()
+  const router = useRouter()
 
   const provider = useMemo(() => {
-    return getProviderByName(selectedProvider);
-  }, [selectedProvider, getProviderByName]);
+    return getProviderByName(selectedProvider)
+  }, [selectedProvider, getProviderByName])
 
   const currentProviderId = useMemo(() => {
-    return provider?.provider || selectedProvider;
-  }, [provider, selectedProvider]);
+    return provider?.provider || selectedProvider
+  }, [provider, selectedProvider])
 
   useEffect(() => {
     function setTools() {
       getTools().then((data: MCPTool[]) => {
-        updateTools(data);
-      });
+        updateTools(data)
+      })
     }
-    setTools();
+    setTools()
 
-    let unsubscribe = () => {};
+    let unsubscribe = () => {}
     listen(SystemEvent.MCP_UPDATE, setTools).then((unsub) => {
       // Unsubscribe from the event when the component unmounts
-      unsubscribe = unsub;
-    });
-    return unsubscribe;
-  }, [updateTools]);
+      unsubscribe = unsub
+    })
+    return unsubscribe
+  }, [updateTools])
 
   const getCurrentThread = useCallback(async () => {
-    let currentThread = retrieveThread();
+    let currentThread = retrieveThread()
     if (!currentThread) {
       currentThread = await createThread(
         {
@@ -92,14 +92,14 @@ export const useChat = () => {
           provider: selectedProvider,
         },
         prompt,
-        currentAssistant,
-      );
+        currentAssistant
+      )
       router.navigate({
         to: route.threadsDetail,
         params: { threadId: currentThread.id },
-      });
+      })
     }
-    return currentThread;
+    return currentThread
   }, [
     createThread,
     prompt,
@@ -108,7 +108,7 @@ export const useChat = () => {
     selectedModel?.id,
     selectedProvider,
     currentAssistant,
-  ]);
+  ])
 
   const increaseModelContextSize = useCallback(
     (model: Model, provider: ProviderObject) => {
@@ -118,12 +118,12 @@ export const useChat = () => {
       */
      const ctxSize = Math.max(
        model.settings?.ctx_len?.controller_props.value
-          ? typeof model.settings.ctx_len.controller_props.value === "string"
+          ? typeof model.settings.ctx_len.controller_props.value === 'string'
            ? parseInt(model.settings.ctx_len.controller_props.value as string)
            : (model.settings.ctx_len.controller_props.value as number)
          : 8192,
-        8192,
-      );
+        8192
+      )
      const updatedModel = {
        ...model,
        settings: {
@@ -136,80 +136,80 @@ export const useChat = () => {
            },
          },
        },
-      };
+      }
      // Find the model index in the provider's models array
-      const modelIndex = provider.models.findIndex((m) => m.id === model.id);
+      const modelIndex = provider.models.findIndex((m) => m.id === model.id)
      if (modelIndex !== -1) {
        // Create a copy of the provider's models array
-        const updatedModels = [...provider.models];
+        const updatedModels = [...provider.models]
        // Update the specific model in the array
-        updatedModels[modelIndex] = updatedModel as Model;
+        updatedModels[modelIndex] = updatedModel as Model
        // Update the provider with the new models array
        updateProvider(provider.provider, {
          models: updatedModels,
-        });
+        })
      }
-      stopAllModels();
+      stopAllModels()
    },
-    [updateProvider],
-  );
+    [updateProvider]
+  )
 
  const sendMessage = useCallback(
    async (
      message: string,
      showModal?: () => Promise<unknown>,
-      troubleshooting = true,
+      troubleshooting = true
    ) => {
-      const activeThread = await getCurrentThread();
+      const activeThread = await getCurrentThread()
 
-      resetTokenSpeed();
+      resetTokenSpeed()
 
      const activeProvider = currentProviderId
        ? getProviderByName(currentProviderId)
-        : provider;
-      if (!activeThread || !activeProvider) return;
-      const messages = getMessages(activeThread.id);
-      const abortController = new AbortController();
-      setAbortController(activeThread.id, abortController);
-      updateStreamingContent(emptyThreadContent);
+        : provider
+      if (!activeThread || !activeProvider) return
+      const messages = getMessages(activeThread.id)
+      const abortController = new AbortController()
+      setAbortController(activeThread.id, abortController)
+      updateStreamingContent(emptyThreadContent)
 
      // Do not add new message on retry
      if (troubleshooting)
-        addMessage(newUserThreadContent(activeThread.id, message));
-      updateThreadTimestamp(activeThread.id);
-      setPrompt("");
+        addMessage(newUserThreadContent(activeThread.id, message))
+      updateThreadTimestamp(activeThread.id)
+      setPrompt('')
 
      try {
        if (selectedModel?.id) {
-          updateLoadingModel(true);
+          updateLoadingModel(true)
          await startModel(
            activeProvider,
            selectedModel.id,
-            abortController,
-          ).catch(console.error);
-          updateLoadingModel(false);
+            abortController
+          ).catch(console.error)
+          updateLoadingModel(false)
        }
 
        const builder = new CompletionMessagesBuilder(
          messages,
-          currentAssistant?.instructions,
-        );
+          currentAssistant?.instructions
+        )
 
-        builder.addUserMessage(message);
+        builder.addUserMessage(message)
 
-        let isCompleted = false;
+        let isCompleted = false
        // Filter tools based on model capabilities and available tools for this thread
-        let availableTools = selectedModel?.capabilities?.includes("tools")
+        let availableTools = selectedModel?.capabilities?.includes('tools')
          ? tools.filter((tool) => {
-              const disabledTools = getDisabledToolsForThread(activeThread.id);
-              return !disabledTools.includes(tool.name);
+              const disabledTools = getDisabledToolsForThread(activeThread.id)
+              return !disabledTools.includes(tool.name)
            })
-          : [];
+          : []
 
        // TODO: Later replaced by Agent setup?
-        const followUpWithToolUse = true;
+        const followUpWithToolUse = true
 
        while (!isCompleted && !abortController.signal.aborted) {
          const completion = await sendCompletion(
            activeThread,
@@ -218,51 +218,51 @@ export const useChat = () => {
            abortController,
            availableTools,
            currentAssistant.parameters?.stream === false ? false : true,
-            currentAssistant.parameters as unknown as Record<string, object>,
+            currentAssistant.parameters as unknown as Record<string, object>
            // TODO: replace it with according provider setting later on
            // selectedProvider === 'llama.cpp' && availableTools.length > 0
            //   ? false
            //   : true
-          );
+          )
 
-          if (!completion) throw new Error("No completion received");
-          let accumulatedText = "";
-          const currentCall: ChatCompletionMessageToolCall | null = null;
-          const toolCalls: ChatCompletionMessageToolCall[] = [];
+          if (!completion) throw new Error('No completion received')
+          let accumulatedText = ''
+          const currentCall: ChatCompletionMessageToolCall | null = null
+          const toolCalls: ChatCompletionMessageToolCall[] = []
 
          if (isCompletionResponse(completion)) {
-            accumulatedText = completion.choices[0]?.message?.content || "";
+            accumulatedText = completion.choices[0]?.message?.content || ''
            if (completion.choices[0]?.message?.tool_calls) {
-              toolCalls.push(...completion.choices[0].message.tool_calls);
+              toolCalls.push(...completion.choices[0].message.tool_calls)
            }
          } else {
            for await (const part of completion) {
              // Error message
              if (!part.choices) {
                throw new Error(
-                  "message" in part
+                  'message' in part
                    ? (part.message as string)
-                    : (JSON.stringify(part) ?? ""),
-                );
+                    : (JSON.stringify(part) ?? '')
+                )
              }
-              const delta = part.choices[0]?.delta?.content || "";
+              const delta = part.choices[0]?.delta?.content || ''
              if (part.choices[0]?.delta?.tool_calls) {
-                const calls = extractToolCall(part, currentCall, toolCalls);
+                const calls = extractToolCall(part, currentCall, toolCalls)
                const currentContent = newAssistantThreadContent(
                  activeThread.id,
                  accumulatedText,
                  {
                    tool_calls: calls.map((e) => ({
                      ...e,
-                      state: "pending",
+                      state: 'pending',
                    })),
-                  },
-                );
-                updateStreamingContent(currentContent);
-                await new Promise((resolve) => setTimeout(resolve, 0));
+                  }
+                )
+                updateStreamingContent(currentContent)
+                await new Promise((resolve) => setTimeout(resolve, 0))
              }
              if (delta) {
-                accumulatedText += delta;
+                accumulatedText += delta
                // Create a new object each time to avoid reference issues
                // Use a timeout to prevent React from batching updates too quickly
                const currentContent = newAssistantThreadContent(
@@ -271,13 +271,13 @@ export const useChat = () => {
                  {
                    tool_calls: toolCalls.map((e) => ({
                      ...e,
-                      state: "pending",
+                      state: 'pending',
                    })),
-                  },
-                );
-                updateStreamingContent(currentContent);
-                updateTokenSpeed(currentContent);
-                await new Promise((resolve) => setTimeout(resolve, 0));
+                  }
+                )
+                updateStreamingContent(currentContent)
+                updateTokenSpeed(currentContent)
+                await new Promise((resolve) => setTimeout(resolve, 0))
              }
            }
          }
@@ -286,10 +286,10 @@ export const useChat = () => {
            accumulatedText.length === 0 &&
            toolCalls.length === 0 &&
            activeThread.model?.id &&
-            activeProvider.provider === "llama.cpp"
+            activeProvider.provider === 'llama.cpp'
          ) {
-            await stopModel(activeThread.model.id, "cortex");
-            throw new Error("No response received from the model");
+            await stopModel(activeThread.model.id, 'cortex')
+            throw new Error('No response received from the model')
          }
 
          // Create a final content object for adding to the thread
@@ -298,10 +298,10 @@ export const useChat = () => {
            accumulatedText,
            {
              tokenSpeed: useAppState.getState().tokenSpeed,
-            },
-          );
+            }
+          )
 
-          builder.addAssistantMessage(accumulatedText, undefined, toolCalls);
+          builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
 
          const updatedMessage = await postMessageProcessing(
            toolCalls,
            builder,
@@ -309,41 +309,41 @@ export const useChat = () => {
            abortController,
            approvedTools,
            allowAllMCPPermissions ? undefined : showApprovalModal,
-            allowAllMCPPermissions,
-          );
-          addMessage(updatedMessage ?? finalContent);
-          updateStreamingContent(emptyThreadContent);
-          updateThreadTimestamp(activeThread.id);
+            allowAllMCPPermissions
+          )
+          addMessage(updatedMessage ?? finalContent)
+          updateStreamingContent(emptyThreadContent)
+          updateThreadTimestamp(activeThread.id)
 
-          isCompleted = !toolCalls.length;
+          isCompleted = !toolCalls.length
          // Do not create agent loop if there is no need for it
-          if (!followUpWithToolUse) availableTools = [];
+          if (!followUpWithToolUse) availableTools = []
        }
      } catch (error) {
        const errorMessage =
-          error && typeof error === "object" && "message" in error
+          error && typeof error === 'object' && 'message' in error
            ? error.message
-            : error;
+            : error
        if (
-          typeof errorMessage === "string" &&
+          typeof errorMessage === 'string' &&
          errorMessage.includes(OUT_OF_CONTEXT_SIZE) &&
          selectedModel &&
          troubleshooting
        ) {
          showModal?.().then((confirmed) => {
            if (confirmed) {
-              increaseModelContextSize(selectedModel, activeProvider);
+              increaseModelContextSize(selectedModel, activeProvider)
              setTimeout(() => {
-                sendMessage(message, showModal, false); // Retry sending the message without troubleshooting
-              }, 1000);
+                sendMessage(message, showModal, false) // Retry sending the message without troubleshooting
+              }, 1000)
            }
-          });
+          })
        }
-        toast.error(`Error sending message: ${errorMessage}`);
-        console.error("Error sending message:", error);
+        toast.error(`Error sending message: ${errorMessage}`)
+        console.error('Error sending message:', error)
      } finally {
-        updateLoadingModel(false);
-        updateStreamingContent(undefined);
+        updateLoadingModel(false)
+        updateStreamingContent(undefined)
      }
    },
    [
@@ -368,8 +368,8 @@ export const useChat = () => {
      showApprovalModal,
      updateTokenSpeed,
      increaseModelContextSize,
-    ],
-  );
+    ]
+  )
 
-  return { sendMessage };
-};
+  return { sendMessage }
+}
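
A working tree can be checked against (or rewritten to) this style with the standard Prettier CLI; the commands below are a generic example, and the project's actual npm scripts may differ:

npx prettier --check .   # report files that deviate from the configured style
npx prettier --write .   # rewrite files in place, producing exactly this kind of diff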