feat: adding model params (#886)

* feat: adding model params

Signed-off-by: James <james@jan.ai>

* chore: inference request parameter

* Improve ui right panel model params

* Remove unused import

* Update slider track for darkmode

---------

Signed-off-by: James <james@jan.ai>
Co-authored-by: James <james@jan.ai>
Co-authored-by: Louis <louis@jan.ai>
Co-authored-by: Faisal Amir <urmauur@gmail.com>
This commit is contained in:
NamH 2023-12-11 23:58:24 +07:00 committed by GitHub
parent 5f7001d294
commit 121dc119f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 765 additions and 203 deletions

View File

@ -71,9 +71,9 @@ export type ThreadMessage = {
object: string;
/** Thread id, default is a ulid. **/
thread_id: string;
/** The role of the author of this message. **/
/** The assistant id of this thread. **/
assistant_id?: string;
// TODO: comment
/** The role of the author of this message. **/
role: ChatCompletionRole;
/** The content of this message. **/
content: ThreadContent[];
@ -125,8 +125,6 @@ export interface Thread {
title: string;
/** Assistants in this thread. **/
assistants: ThreadAssistantInfo[];
// if the thread has been init will full assistant info
isFinishInit: boolean;
/** The timestamp indicating when this thread was created, represented in ISO 8601 format. **/
created: number;
/** The timestamp indicating when this thread was updated, represented in ISO 8601 format. **/
@ -166,6 +164,7 @@ export type ThreadState = {
waitingForResponse: boolean;
error?: Error;
lastMessage?: string;
isFinishInit?: boolean;
};
/**
* Represents the inference engine.
@ -291,6 +290,9 @@ export type ModelRuntimeParams = {
top_p?: number;
stream?: boolean;
max_tokens?: number;
stop?: string[];
frequency_penalty?: number;
presence_penalty?: number;
};
/**

View File

@ -22,10 +22,6 @@ export default class JanAssistantExtension implements AssistantExtension {
onUnload(): void {}
async createAssistant(assistant: Assistant): Promise<void> {
// assuming that assistants/ directory is already created in the onLoad above
// TODO: check if the directory already exists, then ignore creation for now
const assistantDir = join(JanAssistantExtension._homeDir, assistant.id);
await fs.mkdir(assistantDir);
@ -91,7 +87,7 @@ export default class JanAssistantExtension implements AssistantExtension {
avatar: "",
thread_location: undefined,
id: "jan",
object: "assistant", // TODO: maybe we can set default value for this?
object: "assistant",
created_at: Date.now(),
name: "Jan",
description: "A default assistant that can use all downloaded models",

View File

@ -7,7 +7,6 @@ import { Observable } from "rxjs";
*/
export function requestInference(
recentMessages: any[],
engine: EngineSettings,
model: Model,
controller?: AbortController
): Observable<string> {
@ -23,34 +22,41 @@ export function requestInference(
headers: {
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*",
Accept: "text/event-stream",
Accept: model.parameters.stream
? "text/event-stream"
: "application/json",
},
body: requestBody,
signal: controller?.signal,
})
.then(async (response) => {
const stream = response.body;
const decoder = new TextDecoder("utf-8");
const reader = stream?.getReader();
let content = "";
if (model.parameters.stream) {
const stream = response.body;
const decoder = new TextDecoder("utf-8");
const reader = stream?.getReader();
let content = "";
while (true && reader) {
const { done, value } = await reader.read();
if (done) {
break;
}
const text = decoder.decode(value);
const lines = text.trim().split("\n");
for (const line of lines) {
if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
const data = JSON.parse(line.replace("data: ", ""));
content += data.choices[0]?.delta?.content ?? "";
if (content.startsWith("assistant: ")) {
content = content.replace("assistant: ", "");
while (true && reader) {
const { done, value } = await reader.read();
if (done) {
break;
}
const text = decoder.decode(value);
const lines = text.trim().split("\n");
for (const line of lines) {
if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
const data = JSON.parse(line.replace("data: ", ""));
content += data.choices[0]?.delta?.content ?? "";
if (content.startsWith("assistant: ")) {
content = content.replace("assistant: ", "");
}
subscriber.next(content);
}
subscriber.next(content);
}
}
} else {
const data = await response.json();
subscriber.next(data.choices[0]?.message?.content ?? "");
}
subscriber.complete();
})

View File

@ -85,7 +85,6 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
*/
onUnload(): void {}
private async writeDefaultEngineSettings() {
try {
const engineFile = join(
@ -164,7 +163,6 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
return new Promise(async (resolve, reject) => {
requestInference(
data.messages ?? [],
JanInferenceNitroExtension._engineSettings,
JanInferenceNitroExtension._currentModel
).subscribe({
next: (_content) => {},
@ -210,8 +208,7 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
requestInference(
data.messages ?? [],
JanInferenceNitroExtension._engineSettings,
JanInferenceNitroExtension._currentModel,
{ ...JanInferenceNitroExtension._currentModel, ...data.model },
instance.controller
).subscribe({
next: (content) => {

View File

@ -15,9 +15,9 @@ export function requestInference(
controller?: AbortController
): Observable<string> {
return new Observable((subscriber) => {
let model_id: string = model.id
if (engine.full_url.includes("openai.azure.com")){
model_id = engine.full_url.split("/")[5]
let model_id: string = model.id;
if (engine.full_url.includes("openai.azure.com")) {
model_id = engine.full_url.split("/")[5];
}
const requestBody = JSON.stringify({
messages: recentMessages,
@ -29,7 +29,9 @@ export function requestInference(
method: "POST",
headers: {
"Content-Type": "application/json",
Accept: "text/event-stream",
Accept: model.parameters.stream
? "text/event-stream"
: "application/json",
"Access-Control-Allow-Origin": "*",
Authorization: `Bearer ${engine.api_key}`,
"api-key": `${engine.api_key}`,
@ -38,28 +40,33 @@ export function requestInference(
signal: controller?.signal,
})
.then(async (response) => {
const stream = response.body;
const decoder = new TextDecoder("utf-8");
const reader = stream?.getReader();
let content = "";
if (model.parameters.stream) {
const stream = response.body;
const decoder = new TextDecoder("utf-8");
const reader = stream?.getReader();
let content = "";
while (true && reader) {
const { done, value } = await reader.read();
if (done) {
break;
}
const text = decoder.decode(value);
const lines = text.trim().split("\n");
for (const line of lines) {
if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
const data = JSON.parse(line.replace("data: ", ""));
content += data.choices[0]?.delta?.content ?? "";
if (content.startsWith("assistant: ")) {
content = content.replace("assistant: ", "");
while (true && reader) {
const { done, value } = await reader.read();
if (done) {
break;
}
const text = decoder.decode(value);
const lines = text.trim().split("\n");
for (const line of lines) {
if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
const data = JSON.parse(line.replace("data: ", ""));
content += data.choices[0]?.delta?.content ?? "";
if (content.startsWith("assistant: ")) {
content = content.replace("assistant: ", "");
}
subscriber.next(content);
}
subscriber.next(content);
}
}
} else {
const data = await response.json();
subscriber.next(data.choices[0]?.message?.content ?? "");
}
subscriber.complete();
})

View File

@ -25,6 +25,7 @@
"@radix-ui/react-progress": "^1.0.3",
"@radix-ui/react-scroll-area": "^1.0.5",
"@radix-ui/react-select": "^2.0.0",
"@radix-ui/react-slider": "^1.1.2",
"@radix-ui/react-slot": "^1.0.2",
"@radix-ui/react-switch": "^1.0.3",
"@radix-ui/react-toast": "^1.1.5",

View File

@ -11,3 +11,4 @@ export * from './modal'
export * from './command'
export * from './textarea'
export * from './select'
export * from './slider'

View File

@ -1,5 +1,5 @@
.input {
@apply border-border placeholder:text-muted-foreground flex h-9 w-full rounded-md border bg-transparent px-3 py-1 transition-colors;
@apply border-border placeholder:text-muted-foreground flex h-9 w-full rounded-lg border bg-transparent px-3 py-1 transition-colors;
@apply disabled:cursor-not-allowed disabled:opacity-50;
@apply focus-visible:ring-secondary focus-visible:outline-none focus-visible:ring-1;
@apply file:border-0 file:bg-transparent file:font-medium;

View File

@ -15,6 +15,7 @@
@import './command/styles.scss';
@import './textarea/styles.scss';
@import './select/styles.scss';
@import './slider/styles.scss';
.animate-spin {
animation: spin 1s linear infinite;

View File

@ -0,0 +1,25 @@
'use client'
import * as React from 'react'
import * as SliderPrimitive from '@radix-ui/react-slider'
import { twMerge } from 'tailwind-merge'
const Slider = React.forwardRef<
React.ElementRef<typeof SliderPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof SliderPrimitive.Root>
>(({ className, ...props }, ref) => (
<SliderPrimitive.Root
ref={ref}
className={twMerge('slider', className)}
{...props}
>
<SliderPrimitive.Track className="slider-track">
<SliderPrimitive.Range className="slider-range" />
</SliderPrimitive.Track>
<SliderPrimitive.Thumb className="slider-thumb" />
</SliderPrimitive.Root>
))
Slider.displayName = SliderPrimitive.Root.displayName
export { Slider }

View File

@ -0,0 +1,15 @@
.slider {
@apply relative flex w-full touch-none select-none items-center;
&-track {
@apply relative h-1.5 w-full grow overflow-hidden rounded-full bg-gray-200 dark:bg-gray-800;
}
&-range {
@apply absolute h-full bg-blue-600;
}
&-thumb {
@apply border-primary/50 bg-background focus-visible:ring-ring block h-4 w-4 rounded-full border shadow transition-colors focus-visible:outline-none focus-visible:ring-1 disabled:pointer-events-none disabled:opacity-50;
}
}

View File

@ -32,7 +32,7 @@ export default function CardSidebar({
return (
<div
className={twMerge(
'flex w-full flex-col rounded-md border border-border bg-zinc-200 dark:bg-zinc-600/10'
'flex w-full flex-col overflow-hidden rounded-md border border-border bg-zinc-200 dark:bg-zinc-600/10'
)}
>
<div
@ -43,7 +43,7 @@ export default function CardSidebar({
>
<button
onClick={() => setShow(!show)}
className="flex w-full flex-1 items-center space-x-2 px-3 py-2"
className="flex w-full flex-1 items-center space-x-2 bg-zinc-200 px-3 py-2 dark:bg-zinc-600/10"
>
<ChevronDownIcon
className={twMerge(

View File

@ -0,0 +1,62 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { useEffect, useState } from 'react'
import { Switch } from '@janhq/uikit'
import { useAtomValue } from 'jotai'
import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
import {
getActiveThreadIdAtom,
getActiveThreadModelRuntimeParamsAtom,
} from '@/helpers/atoms/Thread.atom'
type Props = {
name: string
title: string
checked: boolean
register: any
}
const Checkbox: React.FC<Props> = ({ name, title, checked, register }) => {
const [currentChecked, setCurrentChecked] = useState<boolean>(checked)
const { updateModelParameter } = useUpdateModelParameters()
const threadId = useAtomValue(getActiveThreadIdAtom)
const activeModelParams = useAtomValue(getActiveThreadModelRuntimeParamsAtom)
useEffect(() => {
setCurrentChecked(checked)
}, [checked])
useEffect(() => {
updateSetting()
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [currentChecked])
const updateValue = [name].reduce((accumulator, value) => {
return { ...accumulator, [value]: currentChecked }
}, {})
const updateSetting = () => {
return updateModelParameter(String(threadId), {
...activeModelParams,
...updateValue,
})
}
return (
<div className="flex justify-between">
<label>{title}</label>
<Switch
checked={currentChecked}
{...register(name)}
onCheckedChange={(e) => {
setCurrentChecked(e)
}}
/>
</div>
)
}
export default Checkbox

View File

@ -26,7 +26,7 @@ import { useMainViewState } from '@/hooks/useMainViewState'
import { toGigabytes } from '@/utils/converter'
import { activeThreadAtom } from '@/helpers/atoms/Conversation.atom'
import { activeThreadAtom, threadStatesAtom } from '@/helpers/atoms/Thread.atom'
export const selectedModelAtom = atom<Model | undefined>(undefined)
@ -36,6 +36,7 @@ export default function DropdownListSidebar() {
const activeThread = useAtomValue(activeThreadAtom)
const [selected, setSelected] = useState<Model | undefined>()
const { setMainViewState } = useMainViewState()
const { activeModel, stateModel } = useActiveModel()
useEffect(() => {
@ -61,13 +62,22 @@ export default function DropdownListSidebar() {
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [activeThread, activeModel, stateModel.loading])
const threadStates = useAtomValue(threadStatesAtom)
if (!activeThread) {
return null
}
const finishInit = threadStates[activeThread.id].isFinishInit ?? true
const onValueSelected = (value: string) => {
setSelected(downloadedModels.filter((x) => x.id === value)[0])
setSelectedModel(downloadedModels.filter((x) => x.id === value)[0])
}
return (
<Select
disabled={finishInit}
value={selected?.id}
onValueChange={(value) => {
setSelected(downloadedModels.filter((x) => x.id === value)[0])
setSelectedModel(downloadedModels.filter((x) => x.id === value)[0])
}}
onValueChange={finishInit ? undefined : onValueSelected}
>
<SelectTrigger className="w-full">
<SelectValue placeholder="Choose model to start">
@ -86,22 +96,20 @@ export default function DropdownListSidebar() {
</div>
) : (
<SelectGroup>
{downloadedModels.map((x, i) => {
return (
<SelectItem
key={i}
value={x.id}
className={twMerge(x.id === selected?.id && 'bg-secondary')}
>
<div className="flex w-full justify-between">
<span className="line-clamp-1 block">{x.name}</span>
<span className="font-bold text-muted-foreground">
{toGigabytes(x.metadata.size)}
</span>
</div>
</SelectItem>
)
})}
{downloadedModels.map((x, i) => (
<SelectItem
key={i}
value={x.id}
className={twMerge(x.id === selected?.id && 'bg-secondary')}
>
<div className="flex w-full justify-between">
<span className="line-clamp-1 block">{x.name}</span>
<span className="font-bold text-muted-foreground">
{toGigabytes(x.metadata.size)}
</span>
</div>
</SelectItem>
))}
</SelectGroup>
)}
<div className="border-b border-border" />

View File

@ -12,7 +12,7 @@ import { useMainViewState } from '@/hooks/useMainViewState'
import { showRightSideBarAtom } from '@/screens/Chat/Sidebar'
import { activeThreadAtom } from '@/helpers/atoms/Conversation.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
const TopBar = () => {
const activeThread = useAtomValue(activeThreadAtom)

View File

@ -25,7 +25,7 @@ import {
import {
updateThreadWaitingForResponseAtom,
threadsAtom,
} from '@/helpers/atoms/Conversation.atom'
} from '@/helpers/atoms/Thread.atom'
export default function EventHandler({ children }: { children: ReactNode }) {
const addNewMessage = useSetAtom(addNewMessageAtom)

View File

@ -0,0 +1,101 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { useEffect, useState } from 'react'
import { Slider, Input } from '@janhq/uikit'
import { useAtomValue } from 'jotai'
import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
import {
getActiveThreadIdAtom,
getActiveThreadModelRuntimeParamsAtom,
} from '@/helpers/atoms/Thread.atom'
type Props = {
name: string
title: string
min: number
max: number
step: number
value: number
register: any
}
const SliderRightPanel: React.FC<Props> = ({
name,
title,
min,
max,
step,
value,
register,
}) => {
const [currentValue, setCurrentValue] = useState<number>(value)
const { updateModelParameter } = useUpdateModelParameters()
const threadId = useAtomValue(getActiveThreadIdAtom)
const activeModelParams = useAtomValue(getActiveThreadModelRuntimeParamsAtom)
useEffect(() => {
setCurrentValue(value)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [value])
useEffect(() => {
updateSetting()
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [currentValue])
const updateValue = [name].reduce((accumulator, value) => {
return { ...accumulator, [value]: currentValue }
}, {})
const updateSetting = () => {
return updateModelParameter(String(threadId), {
...activeModelParams,
...updateValue,
})
}
return (
<div className="flex flex-col">
<p className="mb-2 text-sm font-semibold text-gray-600">{title}</p>
<div className="flex items-center gap-x-4">
<div className="relative w-full">
<Slider
{...register(name, {
setValueAs: (v: any) => parseInt(v),
})}
value={[currentValue]}
onValueChange={async (e) => {
setCurrentValue(Number(e[0]))
await updateSetting()
}}
type="range"
min={min}
max={max}
step={step}
/>
<div className="relative mt-2 flex items-center justify-between text-gray-400">
<p className="text-sm">{min}</p>
<p className="absolute left-1/2 -translate-x-1/2 text-sm">
{max / 2}
</p>
<p className="text-sm">{max}</p>
</div>
</div>
<Input
className="-mt-4 h-8 w-16"
min={min}
max={max}
value={String(currentValue)}
onChange={async (e) => {
setCurrentValue(Number(e.target.value))
await updateSetting()
}}
/>
</div>
</div>
)
}
export default SliderRightPanel

View File

@ -9,7 +9,7 @@ import { atom } from 'jotai'
import {
getActiveThreadIdAtom,
updateThreadStateLastMessageAtom,
} from './Conversation.atom'
} from './Thread.atom'
/**
* Stores all chat messages for all threads
@ -76,15 +76,18 @@ export const addNewMessageAtom = atom(
}
)
export const deleteConversationMessage = atom(null, (get, set, id: string) => {
const newData: Record<string, ThreadMessage[]> = {
...get(chatMessages),
export const deleteChatMessageAtom = atom(
null,
(get, set, threadId: string) => {
const newData: Record<string, ThreadMessage[]> = {
...get(chatMessages),
}
newData[threadId] = []
set(chatMessages, newData)
}
newData[id] = []
set(chatMessages, newData)
})
)
export const cleanConversationMessages = atom(null, (get, set, id: string) => {
export const cleanChatMessageAtom = atom(null, (get, set, id: string) => {
const newData: Record<string, ThreadMessage[]> = {
...get(chatMessages),
}

View File

@ -1,8 +1,13 @@
import { Thread, ThreadContent, ThreadState } from '@janhq/core'
import {
ModelRuntimeParams,
Thread,
ThreadContent,
ThreadState,
} from '@janhq/core'
import { atom } from 'jotai'
/**
* Stores the current active conversation id.
* Stores the current active thread id.
*/
const activeThreadIdAtom = atom<string | undefined>(undefined)
@ -10,7 +15,7 @@ export const getActiveThreadIdAtom = atom((get) => get(activeThreadIdAtom))
export const setActiveThreadIdAtom = atom(
null,
(_get, set, convoId: string | undefined) => set(activeThreadIdAtom, convoId)
(_get, set, threadId: string | undefined) => set(activeThreadIdAtom, threadId)
)
export const waitingToSendMessage = atom<boolean | undefined>(undefined)
@ -20,47 +25,48 @@ export const waitingToSendMessage = atom<boolean | undefined>(undefined)
*/
export const threadStatesAtom = atom<Record<string, ThreadState>>({})
export const activeThreadStateAtom = atom<ThreadState | undefined>((get) => {
const activeConvoId = get(activeThreadIdAtom)
if (!activeConvoId) {
console.debug('Active convo id is undefined')
const threadId = get(activeThreadIdAtom)
if (!threadId) {
console.debug('Active thread id is undefined')
return undefined
}
return get(threadStatesAtom)[activeConvoId]
return get(threadStatesAtom)[threadId]
})
export const deleteThreadStateAtom = atom(
null,
(get, set, threadId: string) => {
const currentState = { ...get(threadStatesAtom) }
delete currentState[threadId]
set(threadStatesAtom, currentState)
}
)
export const updateThreadInitSuccessAtom = atom(
null,
(get, set, threadId: string) => {
const currentState = { ...get(threadStatesAtom) }
currentState[threadId] = {
...currentState[threadId],
isFinishInit: true,
}
set(threadStatesAtom, currentState)
}
)
export const updateThreadWaitingForResponseAtom = atom(
null,
(get, set, conversationId: string, waitingForResponse: boolean) => {
(get, set, threadId: string, waitingForResponse: boolean) => {
const currentState = { ...get(threadStatesAtom) }
currentState[conversationId] = {
...currentState[conversationId],
currentState[threadId] = {
...currentState[threadId],
waitingForResponse,
error: undefined,
}
set(threadStatesAtom, currentState)
}
)
export const updateConversationErrorAtom = atom(
null,
(get, set, conversationId: string, error?: Error) => {
const currentState = { ...get(threadStatesAtom) }
currentState[conversationId] = {
...currentState[conversationId],
error,
}
set(threadStatesAtom, currentState)
}
)
export const updateConversationHasMoreAtom = atom(
null,
(get, set, conversationId: string, hasMore: boolean) => {
const currentState = { ...get(threadStatesAtom) }
currentState[conversationId] = { ...currentState[conversationId], hasMore }
set(threadStatesAtom, currentState)
}
)
export const updateThreadStateLastMessageAtom = atom(
null,
(get, set, threadId: string, lastContent?: ThreadContent[]) => {
@ -100,3 +106,42 @@ export const threadsAtom = atom<Thread[]>([])
export const activeThreadAtom = atom<Thread | undefined>((get) =>
get(threadsAtom).find((c) => c.id === get(getActiveThreadIdAtom))
)
/**
* Store model params at thread level settings
*/
export const threadModelRuntimeParamsAtom = atom<
Record<string, ModelRuntimeParams>
>({})
export const getActiveThreadModelRuntimeParamsAtom = atom<
ModelRuntimeParams | undefined
>((get) => {
const threadId = get(activeThreadIdAtom)
if (!threadId) {
console.debug('Active thread id is undefined')
return undefined
}
return get(threadModelRuntimeParamsAtom)[threadId]
})
export const getThreadModelRuntimeParamsAtom = atom(
(get, threadId: string) => get(threadModelRuntimeParamsAtom)[threadId]
)
export const setThreadModelRuntimeParamsAtom = atom(
null,
(get, set, threadId: string, params: ModelRuntimeParams) => {
const currentState = { ...get(threadModelRuntimeParamsAtom) }
currentState[threadId] = params
console.debug(
`Update model params for thread ${threadId}, ${JSON.stringify(
params,
null,
2
)}`
)
set(threadModelRuntimeParamsAtom, currentState)
}
)

View File

@ -1,8 +1,5 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import {
EventName,
events,
} from '@janhq/core'
import { EventName, events } from '@janhq/core'
import { Model, ModelSettingParams } from '@janhq/core'
import { atom, useAtom } from 'jotai'

View File

@ -6,9 +6,9 @@ import {
ThreadAssistantInfo,
ThreadState,
} from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { atom, useAtomValue, useSetAtom } from 'jotai'
import { generateThreadId } from '@/utils/conversation'
import { generateThreadId } from '@/utils/thread'
import { extensionManager } from '@/extension'
import {
@ -16,7 +16,8 @@ import {
setActiveThreadIdAtom,
threadStatesAtom,
updateThreadAtom,
} from '@/helpers/atoms/Conversation.atom'
setThreadModelRuntimeParamsAtom,
} from '@/helpers/atoms/Thread.atom'
const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
// create thread state for this new thread
@ -25,6 +26,8 @@ const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
const threadState: ThreadState = {
hasMore: false,
waitingForResponse: false,
lastMessage: undefined,
isFinishInit: false,
}
currentState[newThread.id] = threadState
set(threadStatesAtom, currentState)
@ -35,15 +38,26 @@ const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
})
export const useCreateNewThread = () => {
const threadStates = useAtomValue(threadStatesAtom)
const createNewThread = useSetAtom(createNewThreadAtom)
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
const [threadStates, setThreadStates] = useAtom(threadStatesAtom)
const threads = useAtomValue(threadsAtom)
const updateThread = useSetAtom(updateThreadAtom)
const setThreadModelRuntimeParams = useSetAtom(
setThreadModelRuntimeParamsAtom
)
const requestCreateNewThread = async (assistant: Assistant) => {
const unfinishedThreads = threads.filter((t) => t.isFinishInit === false)
if (unfinishedThreads.length > 0) {
// loop through threads state and filter if there's any thread that is not finish init
let hasUnfinishedInitThread = false
for (const key in threadStates) {
const isFinishInit = threadStates[key].isFinishInit ?? true
if (!isFinishInit) {
hasUnfinishedInitThread = true
break
}
}
if (hasUnfinishedInitThread) {
return
}
@ -53,18 +67,10 @@ export const useCreateNewThread = () => {
assistant_name: assistant.name,
model: {
id: '*',
settings: {
ctx_len: 0,
ngl: 0,
embedding: false,
n_parallel: 0,
},
settings: {},
parameters: {
temperature: 0,
token_limit: 0,
top_k: 0,
top_p: 0,
stream: false,
stream: true,
max_tokens: 1024,
},
engine: undefined,
},
@ -78,29 +84,20 @@ export const useCreateNewThread = () => {
assistants: [assistantInfo],
created: createdAt,
updated: createdAt,
isFinishInit: false,
}
// TODO: move isFinishInit here
const threadState: ThreadState = {
hasMore: false,
waitingForResponse: false,
lastMessage: undefined,
}
setThreadStates({ ...threadStates, [threadId]: threadState })
setThreadModelRuntimeParams(thread.id, assistantInfo.model.parameters)
// add the new thread on top of the thread list to the state
createNewThread(thread)
setActiveThreadId(thread.id)
}
function updateThreadMetadata(thread: Thread) {
const updatedThread: Thread = {
...thread,
}
updateThread(updatedThread)
updateThread(thread)
extensionManager
.get<ConversationalExtension>(ExtensionType.Conversational)
?.saveThread(updatedThread)
?.saveThread(thread)
}
return {

View File

@ -11,14 +11,15 @@ import { useActiveModel } from './useActiveModel'
import { extensionManager } from '@/extension/ExtensionManager'
import {
cleanConversationMessages,
deleteConversationMessage,
cleanChatMessageAtom as cleanChatMessagesAtom,
deleteChatMessageAtom as deleteChatMessagesAtom,
getCurrentChatMessagesAtom,
} from '@/helpers/atoms/ChatMessage.atom'
import {
threadsAtom,
setActiveThreadIdAtom,
} from '@/helpers/atoms/Conversation.atom'
deleteThreadStateAtom,
} from '@/helpers/atoms/Thread.atom'
export default function useDeleteThread() {
const { activeModel } = useActiveModel()
@ -26,45 +27,51 @@ export default function useDeleteThread() {
const setCurrentPrompt = useSetAtom(currentPromptAtom)
const messages = useAtomValue(getCurrentChatMessagesAtom)
const setActiveConvoId = useSetAtom(setActiveThreadIdAtom)
const deleteMessages = useSetAtom(deleteConversationMessage)
const cleanMessages = useSetAtom(cleanConversationMessages)
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
const deleteMessages = useSetAtom(deleteChatMessagesAtom)
const cleanMessages = useSetAtom(cleanChatMessagesAtom)
const deleteThreadState = useSetAtom(deleteThreadStateAtom)
const cleanThread = async (threadId: string) => {
if (threadId) {
const thread = threads.filter((c) => c.id === threadId)[0]
cleanMessages(threadId)
const cleanThread = async (activeThreadId: string) => {
if (activeThreadId) {
const thread = threads.filter((c) => c.id === activeThreadId)[0]
cleanMessages(activeThreadId)
if (thread)
await extensionManager
.get<ConversationalExtension>(ExtensionType.Conversational)
?.writeMessages(
activeThreadId,
threadId,
messages.filter((msg) => msg.role === ChatCompletionRole.System)
)
}
}
const deleteThread = async (activeThreadId: string) => {
if (!activeThreadId) {
const deleteThread = async (threadId: string) => {
if (!threadId) {
alert('No active thread')
return
}
try {
await extensionManager
.get<ConversationalExtension>(ExtensionType.Conversational)
?.deleteThread(activeThreadId)
const availableThreads = threads.filter((c) => c.id !== activeThreadId)
?.deleteThread(threadId)
const availableThreads = threads.filter((c) => c.id !== threadId)
setThreads(availableThreads)
deleteMessages(activeThreadId)
// delete the thread state
deleteThreadState(threadId)
deleteMessages(threadId)
setCurrentPrompt('')
toaster({
title: 'Thread successfully deleted.',
description: `Thread with ${activeModel?.name} has been successfully deleted.`,
})
if (availableThreads.length > 0) {
setActiveConvoId(availableThreads[0].id)
setActiveThreadId(availableThreads[0].id)
} else {
setActiveConvoId(undefined)
setActiveThreadId(undefined)
}
} catch (err) {
console.error(err)

View File

@ -1,35 +1,50 @@
import { ExtensionType, ThreadState } from '@janhq/core'
import { ExtensionType, ModelRuntimeParams, ThreadState } from '@janhq/core'
import { ConversationalExtension } from '@janhq/core'
import { useSetAtom } from 'jotai'
import { extensionManager } from '@/extension/ExtensionManager'
import {
threadModelRuntimeParamsAtom,
threadStatesAtom,
threadsAtom,
} from '@/helpers/atoms/Conversation.atom'
} from '@/helpers/atoms/Thread.atom'
const useGetAllThreads = () => {
const setConversationStates = useSetAtom(threadStatesAtom)
const setConversations = useSetAtom(threadsAtom)
const setThreadStates = useSetAtom(threadStatesAtom)
const setThreads = useSetAtom(threadsAtom)
const setThreadModelRuntimeParams = useSetAtom(threadModelRuntimeParamsAtom)
const getAllThreads = async () => {
try {
const threads = await extensionManager
.get<ConversationalExtension>(ExtensionType.Conversational)
?.getThreads()
const threads =
(await extensionManager
.get<ConversationalExtension>(ExtensionType.Conversational)
?.getThreads()) ?? []
const threadStates: Record<string, ThreadState> = {}
threads?.forEach((thread) => {
const threadModelParams: Record<string, ModelRuntimeParams> = {}
threads.forEach((thread) => {
if (thread.id != null) {
const lastMessage = (thread.metadata?.lastMessage as string) ?? ''
threadStates[thread.id] = {
hasMore: true,
waitingForResponse: false,
lastMessage,
isFinishInit: true,
}
// model params
const modelParams = thread.assistants?.[0]?.model?.parameters
threadModelParams[thread.id] = modelParams
}
})
setConversationStates(threadStates)
setConversations(threads ?? [])
// updating app states
setThreadStates(threadStates)
setThreads(threads)
setThreadModelRuntimeParams(threadModelParams)
} catch (error) {
console.error(error)
}

View File

@ -19,9 +19,6 @@ export function useGetConfiguredModels() {
async function fetchModels() {
setLoading(true)
const models = await getConfiguredModels()
if (process.env.NODE_ENV === 'development') {
// models = [dummyModel, ...models] // TODO: NamH add back dummy model later
}
setLoading(false)
setModels(models)
}

View File

@ -12,8 +12,9 @@ import {
ThreadMessage,
events,
Model,
ConversationalExtension,
ModelRuntimeParams,
} from '@janhq/core'
import { ConversationalExtension } from '@janhq/core'
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
import { ulid } from 'ulid'
@ -32,9 +33,12 @@ import {
} from '@/helpers/atoms/ChatMessage.atom'
import {
activeThreadAtom,
getActiveThreadModelRuntimeParamsAtom,
threadStatesAtom,
updateThreadAtom,
updateThreadInitSuccessAtom,
updateThreadWaitingForResponseAtom,
} from '@/helpers/atoms/Conversation.atom'
} from '@/helpers/atoms/Thread.atom'
export default function useSendChatMessage() {
const activeThread = useAtomValue(activeThreadAtom)
@ -50,6 +54,10 @@ export default function useSendChatMessage() {
const [queuedMessage, setQueuedMessage] = useState(false)
const modelRef = useRef<Model | undefined>()
const threadStates = useAtomValue(threadStatesAtom)
const updateThreadInitSuccess = useSetAtom(updateThreadInitSuccessAtom)
const activeModelParams = useAtomValue(getActiveThreadModelRuntimeParamsAtom)
useEffect(() => {
modelRef.current = activeModel
}, [activeModel])
@ -109,7 +117,7 @@ export default function useSendChatMessage() {
return new Promise<void>((resolve) => {
setTimeout(async () => {
if (modelRef.current?.id !== modelId) {
console.log('waiting for model to start')
console.debug('waiting for model to start')
await WaitForModelStarting(modelId)
resolve()
} else {
@ -127,8 +135,10 @@ export default function useSendChatMessage() {
console.error('No active thread')
return
}
const activeThreadState = threadStates[activeThread.id]
if (!activeThread.isFinishInit) {
// if the thread is not initialized, we need to initialize it first
if (!activeThreadState.isFinishInit) {
if (!selectedModel) {
toaster({ title: 'Please select a model' })
return
@ -136,9 +146,14 @@ export default function useSendChatMessage() {
const assistantId = activeThread.assistants[0].assistant_id ?? ''
const assistantName = activeThread.assistants[0].assistant_name ?? ''
const instructions = activeThread.assistants[0].instructions ?? ''
const modelParams: ModelRuntimeParams = {
...selectedModel.parameters,
...activeModelParams,
}
const updatedThread: Thread = {
...activeThread,
isFinishInit: true,
assistants: [
{
assistant_id: assistantId,
@ -147,13 +162,13 @@ export default function useSendChatMessage() {
model: {
id: selectedModel.id,
settings: selectedModel.settings,
parameters: selectedModel.parameters,
parameters: modelParams,
engine: selectedModel.engine,
},
},
],
}
updateThreadInitSuccess(activeThread.id)
updateThread(updatedThread)
extensionManager
@ -191,11 +206,16 @@ export default function useSendChatMessage() {
])
)
const msgId = ulid()
const modelRequest = selectedModel ?? activeThread.assistants[0].model
const messageRequest: MessageRequest = {
id: msgId,
threadId: activeThread.id,
messages,
model: selectedModel ?? activeThread.assistants[0].model,
model: {
...modelRequest,
...(activeModelParams ? { parameters: activeModelParams } : {}),
},
}
const timestamp = Date.now()
const threadMessage: ThreadMessage = {

View File

@ -9,7 +9,7 @@ import { setConvoMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
import {
getActiveThreadIdAtom,
setActiveThreadIdAtom,
} from '@/helpers/atoms/Conversation.atom'
} from '@/helpers/atoms/Thread.atom'
export default function useSetActiveThread() {
const activeThreadId = useAtomValue(getActiveThreadIdAtom)

View File

@ -0,0 +1,66 @@
import {
ConversationalExtension,
ExtensionType,
ModelRuntimeParams,
Thread,
} from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai'
import { extensionManager } from '@/extension'
import {
activeThreadStateAtom,
setThreadModelRuntimeParamsAtom,
threadsAtom,
updateThreadAtom,
} from '@/helpers/atoms/Thread.atom'
export default function useUpdateModelParameters() {
const threads = useAtomValue(threadsAtom)
const updateThread = useSetAtom(updateThreadAtom)
const setThreadModelRuntimeParams = useSetAtom(
setThreadModelRuntimeParamsAtom
)
const activeThreadState = useAtomValue(activeThreadStateAtom)
const updateModelParameter = async (
threadId: string,
params: ModelRuntimeParams
) => {
const thread = threads.find((thread) => thread.id === threadId)
if (!thread) {
console.error(`Thread ${threadId} not found`)
return
}
if (!activeThreadState) {
console.error('No active thread')
return
}
// update the state
setThreadModelRuntimeParams(thread.id, params)
if (!activeThreadState.isFinishInit) {
// if thread is not initialized, we don't need to update thread.json
return
}
const assistants = thread.assistants.map((assistant) => {
assistant.model.parameters = params
return assistant
})
// update thread
const updatedThread: Thread = {
...thread,
assistants,
}
updateThread(updatedThread)
extensionManager
.get<ConversationalExtension>(ExtensionType.Conversational)
?.saveThread(updatedThread)
}
return { updateModelParameter }
}

View File

@ -38,7 +38,6 @@
"sass": "^1.69.4",
"tailwind-merge": "^2.0.0",
"tailwindcss": "3.3.5",
"typescript": "5.2.2",
"ulid": "^2.3.0",
"uuid": "^9.0.1",
"zod": "^3.22.4"

View File

@ -21,7 +21,7 @@ import {
deleteMessageAtom,
getCurrentChatMessagesAtom,
} from '@/helpers/atoms/ChatMessage.atom'
import { activeThreadAtom } from '@/helpers/atoms/Conversation.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
const deleteMessage = useSetAtom(deleteMessageAtom)

View File

@ -0,0 +1,60 @@
import { useEffect, useState } from 'react'
import { useForm } from 'react-hook-form'
import { ModelRuntimeParams } from '@janhq/core'
import { useAtomValue } from 'jotai'
import { presetConfiguration } from './predefinedComponent'
import settingComponentBuilder, {
SettingComponentData,
} from './settingComponentBuilder'
import {
getActiveThreadIdAtom,
getActiveThreadModelRuntimeParamsAtom,
} from '@/helpers/atoms/Thread.atom'
export default function ModelSetting() {
const threadId = useAtomValue(getActiveThreadIdAtom)
const activeModelParams = useAtomValue(getActiveThreadModelRuntimeParamsAtom)
const [modelParams, setModelParams] = useState<
ModelRuntimeParams | undefined
>(activeModelParams)
const { register } = useForm()
useEffect(() => {
setModelParams(activeModelParams)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [threadId])
if (!modelParams) {
return <div>This thread has no model parameters</div>
}
const componentData: SettingComponentData[] = []
Object.keys(modelParams).forEach((key) => {
const componentSetting = presetConfiguration[key]
if (componentSetting) {
if ('value' in componentSetting.controllerData) {
componentSetting.controllerData.value = Number(
modelParams[key as keyof ModelRuntimeParams]
)
} else if ('checked' in componentSetting.controllerData) {
componentSetting.controllerData.checked = modelParams[
key as keyof ModelRuntimeParams
] as boolean
}
componentData.push(componentSetting)
}
})
return (
<form className="flex flex-col">
{settingComponentBuilder(componentData, register)}
</form>
)
}

View File

@ -0,0 +1,59 @@
import { SettingComponentData } from './settingComponentBuilder'
export const presetConfiguration: Record<string, SettingComponentData> = {
max_tokens: {
name: 'max_tokens',
title: 'Max Tokens',
description: 'Maximum context length the model can handle.',
controllerType: 'slider',
controllerData: {
min: 0,
max: 4096,
step: 128,
value: 2048,
},
},
ngl: {
name: 'ngl',
title: 'NGL',
description: 'Number of layers in the neural network.',
controllerType: 'slider',
controllerData: {
min: 1,
max: 100,
step: 1,
value: 100,
},
},
embedding: {
name: 'embedding',
title: 'Embedding',
description: 'Indicates if embedding layers are used.',
controllerType: 'checkbox',
controllerData: {
checked: true,
},
},
stream: {
name: 'stream',
title: 'Stream',
description: 'Stream',
controllerType: 'checkbox',
controllerData: {
checked: false,
},
},
temperature: {
name: 'temperature',
title: 'Temperature',
description:
"Controls randomness in model's responses. Higher values lead to more random responses.",
controllerType: 'slider',
controllerData: {
min: 0,
max: 2,
step: 0.1,
value: 0.7,
},
},
}

View File

@ -0,0 +1,67 @@
/* eslint-disable no-case-declarations */
/* eslint-disable @typescript-eslint/no-explicit-any */
import Checkbox from '@/containers/Checkbox'
import Slider from '@/containers/Slider'
export type ControllerType = 'slider' | 'checkbox'
export type SettingComponentData = {
name: string
title: string
description: string
controllerType: ControllerType
controllerData: SliderData | CheckboxData
}
export type SliderData = {
min: number
max: number
step: number
value: number
}
type CheckboxData = {
checked: boolean
}
const settingComponentBuilder = (
componentData: SettingComponentData[],
register: any
) => {
const components = componentData.map((data) => {
switch (data.controllerType) {
case 'slider':
const { min, max, step, value } = data.controllerData as SliderData
return (
<Slider
key={data.name}
title={data.title}
min={min}
max={max}
step={step}
value={value}
name={data.name}
register={register}
/>
)
case 'checkbox':
const { checked } = data.controllerData as CheckboxData
return (
<Checkbox
key={data.name}
register={register}
name={data.name}
title={data.title}
checked={checked}
/>
)
default:
return null
}
})
return <div className="flex flex-col gap-y-4">{components}</div>
}
export default settingComponentBuilder

View File

@ -16,7 +16,9 @@ import DropdownListSidebar, {
import { useCreateNewThread } from '@/hooks/useCreateNewThread'
import { activeThreadAtom } from '@/helpers/atoms/Conversation.atom'
import ModelSetting from '../ModelSetting'
import { activeThreadAtom, threadStatesAtom } from '@/helpers/atoms/Thread.atom'
export const showRightSideBarAtom = atom<boolean>(true)
@ -25,10 +27,12 @@ export default function Sidebar() {
const activeThread = useAtomValue(activeThreadAtom)
const selectedModel = useAtomValue(selectedModelAtom)
const { updateThreadMetadata } = useCreateNewThread()
const threadStates = useAtomValue(threadStatesAtom)
const onReviewInFinderClick = async (type: string) => {
if (!activeThread) return
if (!activeThread.isFinishInit) {
const activeThreadState = threadStates[activeThread.id]
if (!activeThreadState.isFinishInit) {
alert('Thread is not started yet')
return
}
@ -60,7 +64,8 @@ export default function Sidebar() {
const onViewJsonClick = async (type: string) => {
if (!activeThread) return
if (!activeThread.isFinishInit) {
const activeThreadState = threadStates[activeThread.id]
if (!activeThreadState.isFinishInit) {
alert('Thread is not started yet')
return
}
@ -189,6 +194,9 @@ export default function Sidebar() {
>
<div className="p-2">
<DropdownListSidebar />
<div className="mt-4">
<ModelSetting />
</div>
</div>
</CardSidebar>
</div>

View File

@ -12,7 +12,7 @@ import {
import { twMerge } from 'tailwind-merge'
import { useCreateNewThread } from '@/hooks/useCreateNewThread'
import useDeleteThread from '@/hooks/useDeleteConversation'
import useDeleteThread from '@/hooks/useDeleteThread'
import useGetAllThreads from '@/hooks/useGetAllThreads'
import useGetAssistants from '@/hooks/useGetAssistants'
@ -25,7 +25,7 @@ import {
activeThreadAtom,
threadStatesAtom,
threadsAtom,
} from '@/helpers/atoms/Conversation.atom'
} from '@/helpers/atoms/Thread.atom'
export default function ThreadList() {
const threads = useAtomValue(threadsAtom)

View File

@ -29,9 +29,9 @@ import {
activeThreadAtom,
getActiveThreadIdAtom,
waitingToSendMessage,
} from '@/helpers/atoms/Conversation.atom'
} from '@/helpers/atoms/Thread.atom'
import { activeThreadStateAtom } from '@/helpers/atoms/Conversation.atom'
import { activeThreadStateAtom } from '@/helpers/atoms/Thread.atom'
const ChatScreen = () => {
const activeThread = useAtomValue(activeThreadAtom)