feat: adding model params (#886)
* feat: adding model params Signed-off-by: James <james@jan.ai> * chore: inference request parameter * Improve ui right panel model params * Remove unused import * Update slider track for darkmode --------- Signed-off-by: James <james@jan.ai> Co-authored-by: James <james@jan.ai> Co-authored-by: Louis <louis@jan.ai> Co-authored-by: Faisal Amir <urmauur@gmail.com>
This commit is contained in:
parent
5f7001d294
commit
121dc119f1
@ -71,9 +71,9 @@ export type ThreadMessage = {
|
||||
object: string;
|
||||
/** Thread id, default is a ulid. **/
|
||||
thread_id: string;
|
||||
/** The role of the author of this message. **/
|
||||
/** The assistant id of this thread. **/
|
||||
assistant_id?: string;
|
||||
// TODO: comment
|
||||
/** The role of the author of this message. **/
|
||||
role: ChatCompletionRole;
|
||||
/** The content of this message. **/
|
||||
content: ThreadContent[];
|
||||
@ -125,8 +125,6 @@ export interface Thread {
|
||||
title: string;
|
||||
/** Assistants in this thread. **/
|
||||
assistants: ThreadAssistantInfo[];
|
||||
// if the thread has been init will full assistant info
|
||||
isFinishInit: boolean;
|
||||
/** The timestamp indicating when this thread was created, represented in ISO 8601 format. **/
|
||||
created: number;
|
||||
/** The timestamp indicating when this thread was updated, represented in ISO 8601 format. **/
|
||||
@ -166,6 +164,7 @@ export type ThreadState = {
|
||||
waitingForResponse: boolean;
|
||||
error?: Error;
|
||||
lastMessage?: string;
|
||||
isFinishInit?: boolean;
|
||||
};
|
||||
/**
|
||||
* Represents the inference engine.
|
||||
@ -291,6 +290,9 @@ export type ModelRuntimeParams = {
|
||||
top_p?: number;
|
||||
stream?: boolean;
|
||||
max_tokens?: number;
|
||||
stop?: string[];
|
||||
frequency_penalty?: number;
|
||||
presence_penalty?: number;
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@ -22,10 +22,6 @@ export default class JanAssistantExtension implements AssistantExtension {
|
||||
onUnload(): void {}
|
||||
|
||||
async createAssistant(assistant: Assistant): Promise<void> {
|
||||
// assuming that assistants/ directory is already created in the onLoad above
|
||||
|
||||
// TODO: check if the directory already exists, then ignore creation for now
|
||||
|
||||
const assistantDir = join(JanAssistantExtension._homeDir, assistant.id);
|
||||
await fs.mkdir(assistantDir);
|
||||
|
||||
@ -91,7 +87,7 @@ export default class JanAssistantExtension implements AssistantExtension {
|
||||
avatar: "",
|
||||
thread_location: undefined,
|
||||
id: "jan",
|
||||
object: "assistant", // TODO: maybe we can set default value for this?
|
||||
object: "assistant",
|
||||
created_at: Date.now(),
|
||||
name: "Jan",
|
||||
description: "A default assistant that can use all downloaded models",
|
||||
|
||||
@ -7,7 +7,6 @@ import { Observable } from "rxjs";
|
||||
*/
|
||||
export function requestInference(
|
||||
recentMessages: any[],
|
||||
engine: EngineSettings,
|
||||
model: Model,
|
||||
controller?: AbortController
|
||||
): Observable<string> {
|
||||
@ -23,34 +22,41 @@ export function requestInference(
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
Accept: "text/event-stream",
|
||||
Accept: model.parameters.stream
|
||||
? "text/event-stream"
|
||||
: "application/json",
|
||||
},
|
||||
body: requestBody,
|
||||
signal: controller?.signal,
|
||||
})
|
||||
.then(async (response) => {
|
||||
const stream = response.body;
|
||||
const decoder = new TextDecoder("utf-8");
|
||||
const reader = stream?.getReader();
|
||||
let content = "";
|
||||
if (model.parameters.stream) {
|
||||
const stream = response.body;
|
||||
const decoder = new TextDecoder("utf-8");
|
||||
const reader = stream?.getReader();
|
||||
let content = "";
|
||||
|
||||
while (true && reader) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
const text = decoder.decode(value);
|
||||
const lines = text.trim().split("\n");
|
||||
for (const line of lines) {
|
||||
if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
|
||||
const data = JSON.parse(line.replace("data: ", ""));
|
||||
content += data.choices[0]?.delta?.content ?? "";
|
||||
if (content.startsWith("assistant: ")) {
|
||||
content = content.replace("assistant: ", "");
|
||||
while (true && reader) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
const text = decoder.decode(value);
|
||||
const lines = text.trim().split("\n");
|
||||
for (const line of lines) {
|
||||
if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
|
||||
const data = JSON.parse(line.replace("data: ", ""));
|
||||
content += data.choices[0]?.delta?.content ?? "";
|
||||
if (content.startsWith("assistant: ")) {
|
||||
content = content.replace("assistant: ", "");
|
||||
}
|
||||
subscriber.next(content);
|
||||
}
|
||||
subscriber.next(content);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const data = await response.json();
|
||||
subscriber.next(data.choices[0]?.message?.content ?? "");
|
||||
}
|
||||
subscriber.complete();
|
||||
})
|
||||
|
||||
@ -85,7 +85,6 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
|
||||
*/
|
||||
onUnload(): void {}
|
||||
|
||||
|
||||
private async writeDefaultEngineSettings() {
|
||||
try {
|
||||
const engineFile = join(
|
||||
@ -164,7 +163,6 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
|
||||
return new Promise(async (resolve, reject) => {
|
||||
requestInference(
|
||||
data.messages ?? [],
|
||||
JanInferenceNitroExtension._engineSettings,
|
||||
JanInferenceNitroExtension._currentModel
|
||||
).subscribe({
|
||||
next: (_content) => {},
|
||||
@ -210,8 +208,7 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
|
||||
|
||||
requestInference(
|
||||
data.messages ?? [],
|
||||
JanInferenceNitroExtension._engineSettings,
|
||||
JanInferenceNitroExtension._currentModel,
|
||||
{ ...JanInferenceNitroExtension._currentModel, ...data.model },
|
||||
instance.controller
|
||||
).subscribe({
|
||||
next: (content) => {
|
||||
|
||||
@ -15,9 +15,9 @@ export function requestInference(
|
||||
controller?: AbortController
|
||||
): Observable<string> {
|
||||
return new Observable((subscriber) => {
|
||||
let model_id: string = model.id
|
||||
if (engine.full_url.includes("openai.azure.com")){
|
||||
model_id = engine.full_url.split("/")[5]
|
||||
let model_id: string = model.id;
|
||||
if (engine.full_url.includes("openai.azure.com")) {
|
||||
model_id = engine.full_url.split("/")[5];
|
||||
}
|
||||
const requestBody = JSON.stringify({
|
||||
messages: recentMessages,
|
||||
@ -29,7 +29,9 @@ export function requestInference(
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Accept: "text/event-stream",
|
||||
Accept: model.parameters.stream
|
||||
? "text/event-stream"
|
||||
: "application/json",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
Authorization: `Bearer ${engine.api_key}`,
|
||||
"api-key": `${engine.api_key}`,
|
||||
@ -38,28 +40,33 @@ export function requestInference(
|
||||
signal: controller?.signal,
|
||||
})
|
||||
.then(async (response) => {
|
||||
const stream = response.body;
|
||||
const decoder = new TextDecoder("utf-8");
|
||||
const reader = stream?.getReader();
|
||||
let content = "";
|
||||
if (model.parameters.stream) {
|
||||
const stream = response.body;
|
||||
const decoder = new TextDecoder("utf-8");
|
||||
const reader = stream?.getReader();
|
||||
let content = "";
|
||||
|
||||
while (true && reader) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
const text = decoder.decode(value);
|
||||
const lines = text.trim().split("\n");
|
||||
for (const line of lines) {
|
||||
if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
|
||||
const data = JSON.parse(line.replace("data: ", ""));
|
||||
content += data.choices[0]?.delta?.content ?? "";
|
||||
if (content.startsWith("assistant: ")) {
|
||||
content = content.replace("assistant: ", "");
|
||||
while (true && reader) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
const text = decoder.decode(value);
|
||||
const lines = text.trim().split("\n");
|
||||
for (const line of lines) {
|
||||
if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
|
||||
const data = JSON.parse(line.replace("data: ", ""));
|
||||
content += data.choices[0]?.delta?.content ?? "";
|
||||
if (content.startsWith("assistant: ")) {
|
||||
content = content.replace("assistant: ", "");
|
||||
}
|
||||
subscriber.next(content);
|
||||
}
|
||||
subscriber.next(content);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const data = await response.json();
|
||||
subscriber.next(data.choices[0]?.message?.content ?? "");
|
||||
}
|
||||
subscriber.complete();
|
||||
})
|
||||
|
||||
@ -25,6 +25,7 @@
|
||||
"@radix-ui/react-progress": "^1.0.3",
|
||||
"@radix-ui/react-scroll-area": "^1.0.5",
|
||||
"@radix-ui/react-select": "^2.0.0",
|
||||
"@radix-ui/react-slider": "^1.1.2",
|
||||
"@radix-ui/react-slot": "^1.0.2",
|
||||
"@radix-ui/react-switch": "^1.0.3",
|
||||
"@radix-ui/react-toast": "^1.1.5",
|
||||
|
||||
@ -11,3 +11,4 @@ export * from './modal'
|
||||
export * from './command'
|
||||
export * from './textarea'
|
||||
export * from './select'
|
||||
export * from './slider'
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
.input {
|
||||
@apply border-border placeholder:text-muted-foreground flex h-9 w-full rounded-md border bg-transparent px-3 py-1 transition-colors;
|
||||
@apply border-border placeholder:text-muted-foreground flex h-9 w-full rounded-lg border bg-transparent px-3 py-1 transition-colors;
|
||||
@apply disabled:cursor-not-allowed disabled:opacity-50;
|
||||
@apply focus-visible:ring-secondary focus-visible:outline-none focus-visible:ring-1;
|
||||
@apply file:border-0 file:bg-transparent file:font-medium;
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
@import './command/styles.scss';
|
||||
@import './textarea/styles.scss';
|
||||
@import './select/styles.scss';
|
||||
@import './slider/styles.scss';
|
||||
|
||||
.animate-spin {
|
||||
animation: spin 1s linear infinite;
|
||||
|
||||
25
uikit/src/slider/index.tsx
Normal file
25
uikit/src/slider/index.tsx
Normal file
@ -0,0 +1,25 @@
|
||||
'use client'
|
||||
|
||||
import * as React from 'react'
|
||||
import * as SliderPrimitive from '@radix-ui/react-slider'
|
||||
|
||||
import { twMerge } from 'tailwind-merge'
|
||||
|
||||
const Slider = React.forwardRef<
|
||||
React.ElementRef<typeof SliderPrimitive.Root>,
|
||||
React.ComponentPropsWithoutRef<typeof SliderPrimitive.Root>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<SliderPrimitive.Root
|
||||
ref={ref}
|
||||
className={twMerge('slider', className)}
|
||||
{...props}
|
||||
>
|
||||
<SliderPrimitive.Track className="slider-track">
|
||||
<SliderPrimitive.Range className="slider-range" />
|
||||
</SliderPrimitive.Track>
|
||||
<SliderPrimitive.Thumb className="slider-thumb" />
|
||||
</SliderPrimitive.Root>
|
||||
))
|
||||
Slider.displayName = SliderPrimitive.Root.displayName
|
||||
|
||||
export { Slider }
|
||||
15
uikit/src/slider/styles.scss
Normal file
15
uikit/src/slider/styles.scss
Normal file
@ -0,0 +1,15 @@
|
||||
.slider {
|
||||
@apply relative flex w-full touch-none select-none items-center;
|
||||
|
||||
&-track {
|
||||
@apply relative h-1.5 w-full grow overflow-hidden rounded-full bg-gray-200 dark:bg-gray-800;
|
||||
}
|
||||
|
||||
&-range {
|
||||
@apply absolute h-full bg-blue-600;
|
||||
}
|
||||
|
||||
&-thumb {
|
||||
@apply border-primary/50 bg-background focus-visible:ring-ring block h-4 w-4 rounded-full border shadow transition-colors focus-visible:outline-none focus-visible:ring-1 disabled:pointer-events-none disabled:opacity-50;
|
||||
}
|
||||
}
|
||||
@ -32,7 +32,7 @@ export default function CardSidebar({
|
||||
return (
|
||||
<div
|
||||
className={twMerge(
|
||||
'flex w-full flex-col rounded-md border border-border bg-zinc-200 dark:bg-zinc-600/10'
|
||||
'flex w-full flex-col overflow-hidden rounded-md border border-border bg-zinc-200 dark:bg-zinc-600/10'
|
||||
)}
|
||||
>
|
||||
<div
|
||||
@ -43,7 +43,7 @@ export default function CardSidebar({
|
||||
>
|
||||
<button
|
||||
onClick={() => setShow(!show)}
|
||||
className="flex w-full flex-1 items-center space-x-2 px-3 py-2"
|
||||
className="flex w-full flex-1 items-center space-x-2 bg-zinc-200 px-3 py-2 dark:bg-zinc-600/10"
|
||||
>
|
||||
<ChevronDownIcon
|
||||
className={twMerge(
|
||||
|
||||
62
web/containers/Checkbox/index.tsx
Normal file
62
web/containers/Checkbox/index.tsx
Normal file
@ -0,0 +1,62 @@
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import { useEffect, useState } from 'react'
|
||||
|
||||
import { Switch } from '@janhq/uikit'
|
||||
|
||||
import { useAtomValue } from 'jotai'
|
||||
|
||||
import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
|
||||
|
||||
import {
|
||||
getActiveThreadIdAtom,
|
||||
getActiveThreadModelRuntimeParamsAtom,
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
type Props = {
|
||||
name: string
|
||||
title: string
|
||||
checked: boolean
|
||||
register: any
|
||||
}
|
||||
|
||||
const Checkbox: React.FC<Props> = ({ name, title, checked, register }) => {
|
||||
const [currentChecked, setCurrentChecked] = useState<boolean>(checked)
|
||||
const { updateModelParameter } = useUpdateModelParameters()
|
||||
const threadId = useAtomValue(getActiveThreadIdAtom)
|
||||
const activeModelParams = useAtomValue(getActiveThreadModelRuntimeParamsAtom)
|
||||
|
||||
useEffect(() => {
|
||||
setCurrentChecked(checked)
|
||||
}, [checked])
|
||||
|
||||
useEffect(() => {
|
||||
updateSetting()
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [currentChecked])
|
||||
|
||||
const updateValue = [name].reduce((accumulator, value) => {
|
||||
return { ...accumulator, [value]: currentChecked }
|
||||
}, {})
|
||||
|
||||
const updateSetting = () => {
|
||||
return updateModelParameter(String(threadId), {
|
||||
...activeModelParams,
|
||||
...updateValue,
|
||||
})
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex justify-between">
|
||||
<label>{title}</label>
|
||||
<Switch
|
||||
checked={currentChecked}
|
||||
{...register(name)}
|
||||
onCheckedChange={(e) => {
|
||||
setCurrentChecked(e)
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default Checkbox
|
||||
@ -26,7 +26,7 @@ import { useMainViewState } from '@/hooks/useMainViewState'
|
||||
|
||||
import { toGigabytes } from '@/utils/converter'
|
||||
|
||||
import { activeThreadAtom } from '@/helpers/atoms/Conversation.atom'
|
||||
import { activeThreadAtom, threadStatesAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
export const selectedModelAtom = atom<Model | undefined>(undefined)
|
||||
|
||||
@ -36,6 +36,7 @@ export default function DropdownListSidebar() {
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
const [selected, setSelected] = useState<Model | undefined>()
|
||||
const { setMainViewState } = useMainViewState()
|
||||
|
||||
const { activeModel, stateModel } = useActiveModel()
|
||||
|
||||
useEffect(() => {
|
||||
@ -61,13 +62,22 @@ export default function DropdownListSidebar() {
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [activeThread, activeModel, stateModel.loading])
|
||||
|
||||
const threadStates = useAtomValue(threadStatesAtom)
|
||||
if (!activeThread) {
|
||||
return null
|
||||
}
|
||||
const finishInit = threadStates[activeThread.id].isFinishInit ?? true
|
||||
|
||||
const onValueSelected = (value: string) => {
|
||||
setSelected(downloadedModels.filter((x) => x.id === value)[0])
|
||||
setSelectedModel(downloadedModels.filter((x) => x.id === value)[0])
|
||||
}
|
||||
|
||||
return (
|
||||
<Select
|
||||
disabled={finishInit}
|
||||
value={selected?.id}
|
||||
onValueChange={(value) => {
|
||||
setSelected(downloadedModels.filter((x) => x.id === value)[0])
|
||||
setSelectedModel(downloadedModels.filter((x) => x.id === value)[0])
|
||||
}}
|
||||
onValueChange={finishInit ? undefined : onValueSelected}
|
||||
>
|
||||
<SelectTrigger className="w-full">
|
||||
<SelectValue placeholder="Choose model to start">
|
||||
@ -86,22 +96,20 @@ export default function DropdownListSidebar() {
|
||||
</div>
|
||||
) : (
|
||||
<SelectGroup>
|
||||
{downloadedModels.map((x, i) => {
|
||||
return (
|
||||
<SelectItem
|
||||
key={i}
|
||||
value={x.id}
|
||||
className={twMerge(x.id === selected?.id && 'bg-secondary')}
|
||||
>
|
||||
<div className="flex w-full justify-between">
|
||||
<span className="line-clamp-1 block">{x.name}</span>
|
||||
<span className="font-bold text-muted-foreground">
|
||||
{toGigabytes(x.metadata.size)}
|
||||
</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
)
|
||||
})}
|
||||
{downloadedModels.map((x, i) => (
|
||||
<SelectItem
|
||||
key={i}
|
||||
value={x.id}
|
||||
className={twMerge(x.id === selected?.id && 'bg-secondary')}
|
||||
>
|
||||
<div className="flex w-full justify-between">
|
||||
<span className="line-clamp-1 block">{x.name}</span>
|
||||
<span className="font-bold text-muted-foreground">
|
||||
{toGigabytes(x.metadata.size)}
|
||||
</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectGroup>
|
||||
)}
|
||||
<div className="border-b border-border" />
|
||||
|
||||
@ -12,7 +12,7 @@ import { useMainViewState } from '@/hooks/useMainViewState'
|
||||
|
||||
import { showRightSideBarAtom } from '@/screens/Chat/Sidebar'
|
||||
|
||||
import { activeThreadAtom } from '@/helpers/atoms/Conversation.atom'
|
||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
const TopBar = () => {
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
|
||||
@ -25,7 +25,7 @@ import {
|
||||
import {
|
||||
updateThreadWaitingForResponseAtom,
|
||||
threadsAtom,
|
||||
} from '@/helpers/atoms/Conversation.atom'
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
export default function EventHandler({ children }: { children: ReactNode }) {
|
||||
const addNewMessage = useSetAtom(addNewMessageAtom)
|
||||
|
||||
101
web/containers/Slider/index.tsx
Normal file
101
web/containers/Slider/index.tsx
Normal file
@ -0,0 +1,101 @@
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import { useEffect, useState } from 'react'
|
||||
|
||||
import { Slider, Input } from '@janhq/uikit'
|
||||
import { useAtomValue } from 'jotai'
|
||||
|
||||
import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
|
||||
|
||||
import {
|
||||
getActiveThreadIdAtom,
|
||||
getActiveThreadModelRuntimeParamsAtom,
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
type Props = {
|
||||
name: string
|
||||
title: string
|
||||
min: number
|
||||
max: number
|
||||
step: number
|
||||
value: number
|
||||
register: any
|
||||
}
|
||||
|
||||
const SliderRightPanel: React.FC<Props> = ({
|
||||
name,
|
||||
title,
|
||||
min,
|
||||
max,
|
||||
step,
|
||||
value,
|
||||
register,
|
||||
}) => {
|
||||
const [currentValue, setCurrentValue] = useState<number>(value)
|
||||
const { updateModelParameter } = useUpdateModelParameters()
|
||||
const threadId = useAtomValue(getActiveThreadIdAtom)
|
||||
const activeModelParams = useAtomValue(getActiveThreadModelRuntimeParamsAtom)
|
||||
|
||||
useEffect(() => {
|
||||
setCurrentValue(value)
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [value])
|
||||
|
||||
useEffect(() => {
|
||||
updateSetting()
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [currentValue])
|
||||
|
||||
const updateValue = [name].reduce((accumulator, value) => {
|
||||
return { ...accumulator, [value]: currentValue }
|
||||
}, {})
|
||||
|
||||
const updateSetting = () => {
|
||||
return updateModelParameter(String(threadId), {
|
||||
...activeModelParams,
|
||||
...updateValue,
|
||||
})
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col">
|
||||
<p className="mb-2 text-sm font-semibold text-gray-600">{title}</p>
|
||||
<div className="flex items-center gap-x-4">
|
||||
<div className="relative w-full">
|
||||
<Slider
|
||||
{...register(name, {
|
||||
setValueAs: (v: any) => parseInt(v),
|
||||
})}
|
||||
value={[currentValue]}
|
||||
onValueChange={async (e) => {
|
||||
setCurrentValue(Number(e[0]))
|
||||
await updateSetting()
|
||||
}}
|
||||
type="range"
|
||||
min={min}
|
||||
max={max}
|
||||
step={step}
|
||||
/>
|
||||
<div className="relative mt-2 flex items-center justify-between text-gray-400">
|
||||
<p className="text-sm">{min}</p>
|
||||
<p className="absolute left-1/2 -translate-x-1/2 text-sm">
|
||||
{max / 2}
|
||||
</p>
|
||||
<p className="text-sm">{max}</p>
|
||||
</div>
|
||||
</div>
|
||||
<Input
|
||||
className="-mt-4 h-8 w-16"
|
||||
min={min}
|
||||
max={max}
|
||||
value={String(currentValue)}
|
||||
onChange={async (e) => {
|
||||
setCurrentValue(Number(e.target.value))
|
||||
await updateSetting()
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default SliderRightPanel
|
||||
@ -9,7 +9,7 @@ import { atom } from 'jotai'
|
||||
import {
|
||||
getActiveThreadIdAtom,
|
||||
updateThreadStateLastMessageAtom,
|
||||
} from './Conversation.atom'
|
||||
} from './Thread.atom'
|
||||
|
||||
/**
|
||||
* Stores all chat messages for all threads
|
||||
@ -76,15 +76,18 @@ export const addNewMessageAtom = atom(
|
||||
}
|
||||
)
|
||||
|
||||
export const deleteConversationMessage = atom(null, (get, set, id: string) => {
|
||||
const newData: Record<string, ThreadMessage[]> = {
|
||||
...get(chatMessages),
|
||||
export const deleteChatMessageAtom = atom(
|
||||
null,
|
||||
(get, set, threadId: string) => {
|
||||
const newData: Record<string, ThreadMessage[]> = {
|
||||
...get(chatMessages),
|
||||
}
|
||||
newData[threadId] = []
|
||||
set(chatMessages, newData)
|
||||
}
|
||||
newData[id] = []
|
||||
set(chatMessages, newData)
|
||||
})
|
||||
)
|
||||
|
||||
export const cleanConversationMessages = atom(null, (get, set, id: string) => {
|
||||
export const cleanChatMessageAtom = atom(null, (get, set, id: string) => {
|
||||
const newData: Record<string, ThreadMessage[]> = {
|
||||
...get(chatMessages),
|
||||
}
|
||||
|
||||
@ -1,8 +1,13 @@
|
||||
import { Thread, ThreadContent, ThreadState } from '@janhq/core'
|
||||
import {
|
||||
ModelRuntimeParams,
|
||||
Thread,
|
||||
ThreadContent,
|
||||
ThreadState,
|
||||
} from '@janhq/core'
|
||||
import { atom } from 'jotai'
|
||||
|
||||
/**
|
||||
* Stores the current active conversation id.
|
||||
* Stores the current active thread id.
|
||||
*/
|
||||
const activeThreadIdAtom = atom<string | undefined>(undefined)
|
||||
|
||||
@ -10,7 +15,7 @@ export const getActiveThreadIdAtom = atom((get) => get(activeThreadIdAtom))
|
||||
|
||||
export const setActiveThreadIdAtom = atom(
|
||||
null,
|
||||
(_get, set, convoId: string | undefined) => set(activeThreadIdAtom, convoId)
|
||||
(_get, set, threadId: string | undefined) => set(activeThreadIdAtom, threadId)
|
||||
)
|
||||
|
||||
export const waitingToSendMessage = atom<boolean | undefined>(undefined)
|
||||
@ -20,47 +25,48 @@ export const waitingToSendMessage = atom<boolean | undefined>(undefined)
|
||||
*/
|
||||
export const threadStatesAtom = atom<Record<string, ThreadState>>({})
|
||||
export const activeThreadStateAtom = atom<ThreadState | undefined>((get) => {
|
||||
const activeConvoId = get(activeThreadIdAtom)
|
||||
if (!activeConvoId) {
|
||||
console.debug('Active convo id is undefined')
|
||||
const threadId = get(activeThreadIdAtom)
|
||||
if (!threadId) {
|
||||
console.debug('Active thread id is undefined')
|
||||
return undefined
|
||||
}
|
||||
|
||||
return get(threadStatesAtom)[activeConvoId]
|
||||
return get(threadStatesAtom)[threadId]
|
||||
})
|
||||
|
||||
export const deleteThreadStateAtom = atom(
|
||||
null,
|
||||
(get, set, threadId: string) => {
|
||||
const currentState = { ...get(threadStatesAtom) }
|
||||
delete currentState[threadId]
|
||||
set(threadStatesAtom, currentState)
|
||||
}
|
||||
)
|
||||
|
||||
export const updateThreadInitSuccessAtom = atom(
|
||||
null,
|
||||
(get, set, threadId: string) => {
|
||||
const currentState = { ...get(threadStatesAtom) }
|
||||
currentState[threadId] = {
|
||||
...currentState[threadId],
|
||||
isFinishInit: true,
|
||||
}
|
||||
set(threadStatesAtom, currentState)
|
||||
}
|
||||
)
|
||||
|
||||
export const updateThreadWaitingForResponseAtom = atom(
|
||||
null,
|
||||
(get, set, conversationId: string, waitingForResponse: boolean) => {
|
||||
(get, set, threadId: string, waitingForResponse: boolean) => {
|
||||
const currentState = { ...get(threadStatesAtom) }
|
||||
currentState[conversationId] = {
|
||||
...currentState[conversationId],
|
||||
currentState[threadId] = {
|
||||
...currentState[threadId],
|
||||
waitingForResponse,
|
||||
error: undefined,
|
||||
}
|
||||
set(threadStatesAtom, currentState)
|
||||
}
|
||||
)
|
||||
export const updateConversationErrorAtom = atom(
|
||||
null,
|
||||
(get, set, conversationId: string, error?: Error) => {
|
||||
const currentState = { ...get(threadStatesAtom) }
|
||||
currentState[conversationId] = {
|
||||
...currentState[conversationId],
|
||||
error,
|
||||
}
|
||||
set(threadStatesAtom, currentState)
|
||||
}
|
||||
)
|
||||
export const updateConversationHasMoreAtom = atom(
|
||||
null,
|
||||
(get, set, conversationId: string, hasMore: boolean) => {
|
||||
const currentState = { ...get(threadStatesAtom) }
|
||||
currentState[conversationId] = { ...currentState[conversationId], hasMore }
|
||||
set(threadStatesAtom, currentState)
|
||||
}
|
||||
)
|
||||
|
||||
export const updateThreadStateLastMessageAtom = atom(
|
||||
null,
|
||||
(get, set, threadId: string, lastContent?: ThreadContent[]) => {
|
||||
@ -100,3 +106,42 @@ export const threadsAtom = atom<Thread[]>([])
|
||||
export const activeThreadAtom = atom<Thread | undefined>((get) =>
|
||||
get(threadsAtom).find((c) => c.id === get(getActiveThreadIdAtom))
|
||||
)
|
||||
|
||||
/**
|
||||
* Store model params at thread level settings
|
||||
*/
|
||||
export const threadModelRuntimeParamsAtom = atom<
|
||||
Record<string, ModelRuntimeParams>
|
||||
>({})
|
||||
|
||||
export const getActiveThreadModelRuntimeParamsAtom = atom<
|
||||
ModelRuntimeParams | undefined
|
||||
>((get) => {
|
||||
const threadId = get(activeThreadIdAtom)
|
||||
if (!threadId) {
|
||||
console.debug('Active thread id is undefined')
|
||||
return undefined
|
||||
}
|
||||
|
||||
return get(threadModelRuntimeParamsAtom)[threadId]
|
||||
})
|
||||
|
||||
export const getThreadModelRuntimeParamsAtom = atom(
|
||||
(get, threadId: string) => get(threadModelRuntimeParamsAtom)[threadId]
|
||||
)
|
||||
|
||||
export const setThreadModelRuntimeParamsAtom = atom(
|
||||
null,
|
||||
(get, set, threadId: string, params: ModelRuntimeParams) => {
|
||||
const currentState = { ...get(threadModelRuntimeParamsAtom) }
|
||||
currentState[threadId] = params
|
||||
console.debug(
|
||||
`Update model params for thread ${threadId}, ${JSON.stringify(
|
||||
params,
|
||||
null,
|
||||
2
|
||||
)}`
|
||||
)
|
||||
set(threadModelRuntimeParamsAtom, currentState)
|
||||
}
|
||||
)
|
||||
@ -1,8 +1,5 @@
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import {
|
||||
EventName,
|
||||
events,
|
||||
} from '@janhq/core'
|
||||
import { EventName, events } from '@janhq/core'
|
||||
import { Model, ModelSettingParams } from '@janhq/core'
|
||||
import { atom, useAtom } from 'jotai'
|
||||
|
||||
|
||||
@ -6,9 +6,9 @@ import {
|
||||
ThreadAssistantInfo,
|
||||
ThreadState,
|
||||
} from '@janhq/core'
|
||||
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||
import { atom, useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
import { generateThreadId } from '@/utils/conversation'
|
||||
import { generateThreadId } from '@/utils/thread'
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
import {
|
||||
@ -16,7 +16,8 @@ import {
|
||||
setActiveThreadIdAtom,
|
||||
threadStatesAtom,
|
||||
updateThreadAtom,
|
||||
} from '@/helpers/atoms/Conversation.atom'
|
||||
setThreadModelRuntimeParamsAtom,
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
|
||||
// create thread state for this new thread
|
||||
@ -25,6 +26,8 @@ const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
|
||||
const threadState: ThreadState = {
|
||||
hasMore: false,
|
||||
waitingForResponse: false,
|
||||
lastMessage: undefined,
|
||||
isFinishInit: false,
|
||||
}
|
||||
currentState[newThread.id] = threadState
|
||||
set(threadStatesAtom, currentState)
|
||||
@ -35,15 +38,26 @@ const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
|
||||
})
|
||||
|
||||
export const useCreateNewThread = () => {
|
||||
const threadStates = useAtomValue(threadStatesAtom)
|
||||
const createNewThread = useSetAtom(createNewThreadAtom)
|
||||
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
|
||||
const [threadStates, setThreadStates] = useAtom(threadStatesAtom)
|
||||
const threads = useAtomValue(threadsAtom)
|
||||
const updateThread = useSetAtom(updateThreadAtom)
|
||||
const setThreadModelRuntimeParams = useSetAtom(
|
||||
setThreadModelRuntimeParamsAtom
|
||||
)
|
||||
|
||||
const requestCreateNewThread = async (assistant: Assistant) => {
|
||||
const unfinishedThreads = threads.filter((t) => t.isFinishInit === false)
|
||||
if (unfinishedThreads.length > 0) {
|
||||
// loop through threads state and filter if there's any thread that is not finish init
|
||||
let hasUnfinishedInitThread = false
|
||||
for (const key in threadStates) {
|
||||
const isFinishInit = threadStates[key].isFinishInit ?? true
|
||||
if (!isFinishInit) {
|
||||
hasUnfinishedInitThread = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if (hasUnfinishedInitThread) {
|
||||
return
|
||||
}
|
||||
|
||||
@ -53,18 +67,10 @@ export const useCreateNewThread = () => {
|
||||
assistant_name: assistant.name,
|
||||
model: {
|
||||
id: '*',
|
||||
settings: {
|
||||
ctx_len: 0,
|
||||
ngl: 0,
|
||||
embedding: false,
|
||||
n_parallel: 0,
|
||||
},
|
||||
settings: {},
|
||||
parameters: {
|
||||
temperature: 0,
|
||||
token_limit: 0,
|
||||
top_k: 0,
|
||||
top_p: 0,
|
||||
stream: false,
|
||||
stream: true,
|
||||
max_tokens: 1024,
|
||||
},
|
||||
engine: undefined,
|
||||
},
|
||||
@ -78,29 +84,20 @@ export const useCreateNewThread = () => {
|
||||
assistants: [assistantInfo],
|
||||
created: createdAt,
|
||||
updated: createdAt,
|
||||
isFinishInit: false,
|
||||
}
|
||||
|
||||
// TODO: move isFinishInit here
|
||||
const threadState: ThreadState = {
|
||||
hasMore: false,
|
||||
waitingForResponse: false,
|
||||
lastMessage: undefined,
|
||||
}
|
||||
setThreadStates({ ...threadStates, [threadId]: threadState })
|
||||
setThreadModelRuntimeParams(thread.id, assistantInfo.model.parameters)
|
||||
|
||||
// add the new thread on top of the thread list to the state
|
||||
createNewThread(thread)
|
||||
setActiveThreadId(thread.id)
|
||||
}
|
||||
|
||||
function updateThreadMetadata(thread: Thread) {
|
||||
const updatedThread: Thread = {
|
||||
...thread,
|
||||
}
|
||||
updateThread(updatedThread)
|
||||
updateThread(thread)
|
||||
extensionManager
|
||||
.get<ConversationalExtension>(ExtensionType.Conversational)
|
||||
?.saveThread(updatedThread)
|
||||
?.saveThread(thread)
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@ -11,14 +11,15 @@ import { useActiveModel } from './useActiveModel'
|
||||
import { extensionManager } from '@/extension/ExtensionManager'
|
||||
|
||||
import {
|
||||
cleanConversationMessages,
|
||||
deleteConversationMessage,
|
||||
cleanChatMessageAtom as cleanChatMessagesAtom,
|
||||
deleteChatMessageAtom as deleteChatMessagesAtom,
|
||||
getCurrentChatMessagesAtom,
|
||||
} from '@/helpers/atoms/ChatMessage.atom'
|
||||
import {
|
||||
threadsAtom,
|
||||
setActiveThreadIdAtom,
|
||||
} from '@/helpers/atoms/Conversation.atom'
|
||||
deleteThreadStateAtom,
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
export default function useDeleteThread() {
|
||||
const { activeModel } = useActiveModel()
|
||||
@ -26,45 +27,51 @@ export default function useDeleteThread() {
|
||||
const setCurrentPrompt = useSetAtom(currentPromptAtom)
|
||||
const messages = useAtomValue(getCurrentChatMessagesAtom)
|
||||
|
||||
const setActiveConvoId = useSetAtom(setActiveThreadIdAtom)
|
||||
const deleteMessages = useSetAtom(deleteConversationMessage)
|
||||
const cleanMessages = useSetAtom(cleanConversationMessages)
|
||||
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
|
||||
const deleteMessages = useSetAtom(deleteChatMessagesAtom)
|
||||
const cleanMessages = useSetAtom(cleanChatMessagesAtom)
|
||||
const deleteThreadState = useSetAtom(deleteThreadStateAtom)
|
||||
|
||||
const cleanThread = async (threadId: string) => {
|
||||
if (threadId) {
|
||||
const thread = threads.filter((c) => c.id === threadId)[0]
|
||||
cleanMessages(threadId)
|
||||
|
||||
const cleanThread = async (activeThreadId: string) => {
|
||||
if (activeThreadId) {
|
||||
const thread = threads.filter((c) => c.id === activeThreadId)[0]
|
||||
cleanMessages(activeThreadId)
|
||||
if (thread)
|
||||
await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionType.Conversational)
|
||||
?.writeMessages(
|
||||
activeThreadId,
|
||||
threadId,
|
||||
messages.filter((msg) => msg.role === ChatCompletionRole.System)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const deleteThread = async (activeThreadId: string) => {
|
||||
if (!activeThreadId) {
|
||||
const deleteThread = async (threadId: string) => {
|
||||
if (!threadId) {
|
||||
alert('No active thread')
|
||||
return
|
||||
}
|
||||
try {
|
||||
await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionType.Conversational)
|
||||
?.deleteThread(activeThreadId)
|
||||
const availableThreads = threads.filter((c) => c.id !== activeThreadId)
|
||||
?.deleteThread(threadId)
|
||||
const availableThreads = threads.filter((c) => c.id !== threadId)
|
||||
setThreads(availableThreads)
|
||||
deleteMessages(activeThreadId)
|
||||
|
||||
// delete the thread state
|
||||
deleteThreadState(threadId)
|
||||
|
||||
deleteMessages(threadId)
|
||||
setCurrentPrompt('')
|
||||
toaster({
|
||||
title: 'Thread successfully deleted.',
|
||||
description: `Thread with ${activeModel?.name} has been successfully deleted.`,
|
||||
})
|
||||
if (availableThreads.length > 0) {
|
||||
setActiveConvoId(availableThreads[0].id)
|
||||
setActiveThreadId(availableThreads[0].id)
|
||||
} else {
|
||||
setActiveConvoId(undefined)
|
||||
setActiveThreadId(undefined)
|
||||
}
|
||||
} catch (err) {
|
||||
console.error(err)
|
||||
@ -1,35 +1,50 @@
|
||||
import { ExtensionType, ThreadState } from '@janhq/core'
|
||||
import { ExtensionType, ModelRuntimeParams, ThreadState } from '@janhq/core'
|
||||
import { ConversationalExtension } from '@janhq/core'
|
||||
import { useSetAtom } from 'jotai'
|
||||
|
||||
import { extensionManager } from '@/extension/ExtensionManager'
|
||||
import {
|
||||
threadModelRuntimeParamsAtom,
|
||||
threadStatesAtom,
|
||||
threadsAtom,
|
||||
} from '@/helpers/atoms/Conversation.atom'
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
const useGetAllThreads = () => {
|
||||
const setConversationStates = useSetAtom(threadStatesAtom)
|
||||
const setConversations = useSetAtom(threadsAtom)
|
||||
const setThreadStates = useSetAtom(threadStatesAtom)
|
||||
const setThreads = useSetAtom(threadsAtom)
|
||||
const setThreadModelRuntimeParams = useSetAtom(threadModelRuntimeParamsAtom)
|
||||
|
||||
const getAllThreads = async () => {
|
||||
try {
|
||||
const threads = await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionType.Conversational)
|
||||
?.getThreads()
|
||||
const threads =
|
||||
(await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionType.Conversational)
|
||||
?.getThreads()) ?? []
|
||||
|
||||
const threadStates: Record<string, ThreadState> = {}
|
||||
threads?.forEach((thread) => {
|
||||
const threadModelParams: Record<string, ModelRuntimeParams> = {}
|
||||
|
||||
threads.forEach((thread) => {
|
||||
if (thread.id != null) {
|
||||
const lastMessage = (thread.metadata?.lastMessage as string) ?? ''
|
||||
|
||||
threadStates[thread.id] = {
|
||||
hasMore: true,
|
||||
waitingForResponse: false,
|
||||
lastMessage,
|
||||
isFinishInit: true,
|
||||
}
|
||||
|
||||
// model params
|
||||
const modelParams = thread.assistants?.[0]?.model?.parameters
|
||||
threadModelParams[thread.id] = modelParams
|
||||
}
|
||||
})
|
||||
setConversationStates(threadStates)
|
||||
setConversations(threads ?? [])
|
||||
|
||||
// updating app states
|
||||
setThreadStates(threadStates)
|
||||
setThreads(threads)
|
||||
setThreadModelRuntimeParams(threadModelParams)
|
||||
} catch (error) {
|
||||
console.error(error)
|
||||
}
|
||||
|
||||
@ -19,9 +19,6 @@ export function useGetConfiguredModels() {
|
||||
async function fetchModels() {
|
||||
setLoading(true)
|
||||
const models = await getConfiguredModels()
|
||||
if (process.env.NODE_ENV === 'development') {
|
||||
// models = [dummyModel, ...models] // TODO: NamH add back dummy model later
|
||||
}
|
||||
setLoading(false)
|
||||
setModels(models)
|
||||
}
|
||||
|
||||
@ -12,8 +12,9 @@ import {
|
||||
ThreadMessage,
|
||||
events,
|
||||
Model,
|
||||
ConversationalExtension,
|
||||
ModelRuntimeParams,
|
||||
} from '@janhq/core'
|
||||
import { ConversationalExtension } from '@janhq/core'
|
||||
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
import { ulid } from 'ulid'
|
||||
@ -32,9 +33,12 @@ import {
|
||||
} from '@/helpers/atoms/ChatMessage.atom'
|
||||
import {
|
||||
activeThreadAtom,
|
||||
getActiveThreadModelRuntimeParamsAtom,
|
||||
threadStatesAtom,
|
||||
updateThreadAtom,
|
||||
updateThreadInitSuccessAtom,
|
||||
updateThreadWaitingForResponseAtom,
|
||||
} from '@/helpers/atoms/Conversation.atom'
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
export default function useSendChatMessage() {
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
@ -50,6 +54,10 @@ export default function useSendChatMessage() {
|
||||
const [queuedMessage, setQueuedMessage] = useState(false)
|
||||
|
||||
const modelRef = useRef<Model | undefined>()
|
||||
const threadStates = useAtomValue(threadStatesAtom)
|
||||
const updateThreadInitSuccess = useSetAtom(updateThreadInitSuccessAtom)
|
||||
const activeModelParams = useAtomValue(getActiveThreadModelRuntimeParamsAtom)
|
||||
|
||||
useEffect(() => {
|
||||
modelRef.current = activeModel
|
||||
}, [activeModel])
|
||||
@ -109,7 +117,7 @@ export default function useSendChatMessage() {
|
||||
return new Promise<void>((resolve) => {
|
||||
setTimeout(async () => {
|
||||
if (modelRef.current?.id !== modelId) {
|
||||
console.log('waiting for model to start')
|
||||
console.debug('waiting for model to start')
|
||||
await WaitForModelStarting(modelId)
|
||||
resolve()
|
||||
} else {
|
||||
@ -127,8 +135,10 @@ export default function useSendChatMessage() {
|
||||
console.error('No active thread')
|
||||
return
|
||||
}
|
||||
const activeThreadState = threadStates[activeThread.id]
|
||||
|
||||
if (!activeThread.isFinishInit) {
|
||||
// if the thread is not initialized, we need to initialize it first
|
||||
if (!activeThreadState.isFinishInit) {
|
||||
if (!selectedModel) {
|
||||
toaster({ title: 'Please select a model' })
|
||||
return
|
||||
@ -136,9 +146,14 @@ export default function useSendChatMessage() {
|
||||
const assistantId = activeThread.assistants[0].assistant_id ?? ''
|
||||
const assistantName = activeThread.assistants[0].assistant_name ?? ''
|
||||
const instructions = activeThread.assistants[0].instructions ?? ''
|
||||
|
||||
const modelParams: ModelRuntimeParams = {
|
||||
...selectedModel.parameters,
|
||||
...activeModelParams,
|
||||
}
|
||||
|
||||
const updatedThread: Thread = {
|
||||
...activeThread,
|
||||
isFinishInit: true,
|
||||
assistants: [
|
||||
{
|
||||
assistant_id: assistantId,
|
||||
@ -147,13 +162,13 @@ export default function useSendChatMessage() {
|
||||
model: {
|
||||
id: selectedModel.id,
|
||||
settings: selectedModel.settings,
|
||||
parameters: selectedModel.parameters,
|
||||
parameters: modelParams,
|
||||
engine: selectedModel.engine,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
updateThreadInitSuccess(activeThread.id)
|
||||
updateThread(updatedThread)
|
||||
|
||||
extensionManager
|
||||
@ -191,11 +206,16 @@ export default function useSendChatMessage() {
|
||||
])
|
||||
)
|
||||
const msgId = ulid()
|
||||
|
||||
const modelRequest = selectedModel ?? activeThread.assistants[0].model
|
||||
const messageRequest: MessageRequest = {
|
||||
id: msgId,
|
||||
threadId: activeThread.id,
|
||||
messages,
|
||||
model: selectedModel ?? activeThread.assistants[0].model,
|
||||
model: {
|
||||
...modelRequest,
|
||||
...(activeModelParams ? { parameters: activeModelParams } : {}),
|
||||
},
|
||||
}
|
||||
const timestamp = Date.now()
|
||||
const threadMessage: ThreadMessage = {
|
||||
|
||||
@ -9,7 +9,7 @@ import { setConvoMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
|
||||
import {
|
||||
getActiveThreadIdAtom,
|
||||
setActiveThreadIdAtom,
|
||||
} from '@/helpers/atoms/Conversation.atom'
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
export default function useSetActiveThread() {
|
||||
const activeThreadId = useAtomValue(getActiveThreadIdAtom)
|
||||
|
||||
66
web/hooks/useUpdateModelParameters.ts
Normal file
66
web/hooks/useUpdateModelParameters.ts
Normal file
@ -0,0 +1,66 @@
|
||||
import {
|
||||
ConversationalExtension,
|
||||
ExtensionType,
|
||||
ModelRuntimeParams,
|
||||
Thread,
|
||||
} from '@janhq/core'
|
||||
|
||||
import { useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
import {
|
||||
activeThreadStateAtom,
|
||||
setThreadModelRuntimeParamsAtom,
|
||||
threadsAtom,
|
||||
updateThreadAtom,
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
export default function useUpdateModelParameters() {
|
||||
const threads = useAtomValue(threadsAtom)
|
||||
const updateThread = useSetAtom(updateThreadAtom)
|
||||
const setThreadModelRuntimeParams = useSetAtom(
|
||||
setThreadModelRuntimeParamsAtom
|
||||
)
|
||||
const activeThreadState = useAtomValue(activeThreadStateAtom)
|
||||
|
||||
const updateModelParameter = async (
|
||||
threadId: string,
|
||||
params: ModelRuntimeParams
|
||||
) => {
|
||||
const thread = threads.find((thread) => thread.id === threadId)
|
||||
if (!thread) {
|
||||
console.error(`Thread ${threadId} not found`)
|
||||
return
|
||||
}
|
||||
|
||||
if (!activeThreadState) {
|
||||
console.error('No active thread')
|
||||
return
|
||||
}
|
||||
|
||||
// update the state
|
||||
setThreadModelRuntimeParams(thread.id, params)
|
||||
|
||||
if (!activeThreadState.isFinishInit) {
|
||||
// if thread is not initialized, we don't need to update thread.json
|
||||
return
|
||||
}
|
||||
|
||||
const assistants = thread.assistants.map((assistant) => {
|
||||
assistant.model.parameters = params
|
||||
return assistant
|
||||
})
|
||||
|
||||
// update thread
|
||||
const updatedThread: Thread = {
|
||||
...thread,
|
||||
assistants,
|
||||
}
|
||||
updateThread(updatedThread)
|
||||
extensionManager
|
||||
.get<ConversationalExtension>(ExtensionType.Conversational)
|
||||
?.saveThread(updatedThread)
|
||||
}
|
||||
|
||||
return { updateModelParameter }
|
||||
}
|
||||
@ -38,7 +38,6 @@
|
||||
"sass": "^1.69.4",
|
||||
"tailwind-merge": "^2.0.0",
|
||||
"tailwindcss": "3.3.5",
|
||||
"typescript": "5.2.2",
|
||||
"ulid": "^2.3.0",
|
||||
"uuid": "^9.0.1",
|
||||
"zod": "^3.22.4"
|
||||
|
||||
@ -21,7 +21,7 @@ import {
|
||||
deleteMessageAtom,
|
||||
getCurrentChatMessagesAtom,
|
||||
} from '@/helpers/atoms/ChatMessage.atom'
|
||||
import { activeThreadAtom } from '@/helpers/atoms/Conversation.atom'
|
||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
|
||||
const deleteMessage = useSetAtom(deleteMessageAtom)
|
||||
|
||||
60
web/screens/Chat/ModelSetting/index.tsx
Normal file
60
web/screens/Chat/ModelSetting/index.tsx
Normal file
@ -0,0 +1,60 @@
|
||||
import { useEffect, useState } from 'react'
|
||||
|
||||
import { useForm } from 'react-hook-form'
|
||||
|
||||
import { ModelRuntimeParams } from '@janhq/core'
|
||||
|
||||
import { useAtomValue } from 'jotai'
|
||||
|
||||
import { presetConfiguration } from './predefinedComponent'
|
||||
import settingComponentBuilder, {
|
||||
SettingComponentData,
|
||||
} from './settingComponentBuilder'
|
||||
|
||||
import {
|
||||
getActiveThreadIdAtom,
|
||||
getActiveThreadModelRuntimeParamsAtom,
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
export default function ModelSetting() {
|
||||
const threadId = useAtomValue(getActiveThreadIdAtom)
|
||||
const activeModelParams = useAtomValue(getActiveThreadModelRuntimeParamsAtom)
|
||||
const [modelParams, setModelParams] = useState<
|
||||
ModelRuntimeParams | undefined
|
||||
>(activeModelParams)
|
||||
|
||||
const { register } = useForm()
|
||||
|
||||
useEffect(() => {
|
||||
setModelParams(activeModelParams)
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [threadId])
|
||||
|
||||
if (!modelParams) {
|
||||
return <div>This thread has no model parameters</div>
|
||||
}
|
||||
|
||||
const componentData: SettingComponentData[] = []
|
||||
Object.keys(modelParams).forEach((key) => {
|
||||
const componentSetting = presetConfiguration[key]
|
||||
|
||||
if (componentSetting) {
|
||||
if ('value' in componentSetting.controllerData) {
|
||||
componentSetting.controllerData.value = Number(
|
||||
modelParams[key as keyof ModelRuntimeParams]
|
||||
)
|
||||
} else if ('checked' in componentSetting.controllerData) {
|
||||
componentSetting.controllerData.checked = modelParams[
|
||||
key as keyof ModelRuntimeParams
|
||||
] as boolean
|
||||
}
|
||||
componentData.push(componentSetting)
|
||||
}
|
||||
})
|
||||
|
||||
return (
|
||||
<form className="flex flex-col">
|
||||
{settingComponentBuilder(componentData, register)}
|
||||
</form>
|
||||
)
|
||||
}
|
||||
59
web/screens/Chat/ModelSetting/predefinedComponent.ts
Normal file
59
web/screens/Chat/ModelSetting/predefinedComponent.ts
Normal file
@ -0,0 +1,59 @@
|
||||
import { SettingComponentData } from './settingComponentBuilder'
|
||||
|
||||
export const presetConfiguration: Record<string, SettingComponentData> = {
|
||||
max_tokens: {
|
||||
name: 'max_tokens',
|
||||
title: 'Max Tokens',
|
||||
description: 'Maximum context length the model can handle.',
|
||||
controllerType: 'slider',
|
||||
controllerData: {
|
||||
min: 0,
|
||||
max: 4096,
|
||||
step: 128,
|
||||
value: 2048,
|
||||
},
|
||||
},
|
||||
ngl: {
|
||||
name: 'ngl',
|
||||
title: 'NGL',
|
||||
description: 'Number of layers in the neural network.',
|
||||
controllerType: 'slider',
|
||||
controllerData: {
|
||||
min: 1,
|
||||
max: 100,
|
||||
step: 1,
|
||||
value: 100,
|
||||
},
|
||||
},
|
||||
embedding: {
|
||||
name: 'embedding',
|
||||
title: 'Embedding',
|
||||
description: 'Indicates if embedding layers are used.',
|
||||
controllerType: 'checkbox',
|
||||
controllerData: {
|
||||
checked: true,
|
||||
},
|
||||
},
|
||||
stream: {
|
||||
name: 'stream',
|
||||
title: 'Stream',
|
||||
description: 'Stream',
|
||||
controllerType: 'checkbox',
|
||||
controllerData: {
|
||||
checked: false,
|
||||
},
|
||||
},
|
||||
temperature: {
|
||||
name: 'temperature',
|
||||
title: 'Temperature',
|
||||
description:
|
||||
"Controls randomness in model's responses. Higher values lead to more random responses.",
|
||||
controllerType: 'slider',
|
||||
controllerData: {
|
||||
min: 0,
|
||||
max: 2,
|
||||
step: 0.1,
|
||||
value: 0.7,
|
||||
},
|
||||
},
|
||||
}
|
||||
67
web/screens/Chat/ModelSetting/settingComponentBuilder.tsx
Normal file
67
web/screens/Chat/ModelSetting/settingComponentBuilder.tsx
Normal file
@ -0,0 +1,67 @@
|
||||
/* eslint-disable no-case-declarations */
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
|
||||
import Checkbox from '@/containers/Checkbox'
|
||||
import Slider from '@/containers/Slider'
|
||||
|
||||
export type ControllerType = 'slider' | 'checkbox'
|
||||
|
||||
export type SettingComponentData = {
|
||||
name: string
|
||||
title: string
|
||||
description: string
|
||||
controllerType: ControllerType
|
||||
controllerData: SliderData | CheckboxData
|
||||
}
|
||||
|
||||
export type SliderData = {
|
||||
min: number
|
||||
max: number
|
||||
step: number
|
||||
value: number
|
||||
}
|
||||
|
||||
type CheckboxData = {
|
||||
checked: boolean
|
||||
}
|
||||
|
||||
const settingComponentBuilder = (
|
||||
componentData: SettingComponentData[],
|
||||
register: any
|
||||
) => {
|
||||
const components = componentData.map((data) => {
|
||||
switch (data.controllerType) {
|
||||
case 'slider':
|
||||
const { min, max, step, value } = data.controllerData as SliderData
|
||||
return (
|
||||
<Slider
|
||||
key={data.name}
|
||||
title={data.title}
|
||||
min={min}
|
||||
max={max}
|
||||
step={step}
|
||||
value={value}
|
||||
name={data.name}
|
||||
register={register}
|
||||
/>
|
||||
)
|
||||
case 'checkbox':
|
||||
const { checked } = data.controllerData as CheckboxData
|
||||
return (
|
||||
<Checkbox
|
||||
key={data.name}
|
||||
register={register}
|
||||
name={data.name}
|
||||
title={data.title}
|
||||
checked={checked}
|
||||
/>
|
||||
)
|
||||
default:
|
||||
return null
|
||||
}
|
||||
})
|
||||
|
||||
return <div className="flex flex-col gap-y-4">{components}</div>
|
||||
}
|
||||
|
||||
export default settingComponentBuilder
|
||||
@ -16,7 +16,9 @@ import DropdownListSidebar, {
|
||||
|
||||
import { useCreateNewThread } from '@/hooks/useCreateNewThread'
|
||||
|
||||
import { activeThreadAtom } from '@/helpers/atoms/Conversation.atom'
|
||||
import ModelSetting from '../ModelSetting'
|
||||
|
||||
import { activeThreadAtom, threadStatesAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
export const showRightSideBarAtom = atom<boolean>(true)
|
||||
|
||||
@ -25,10 +27,12 @@ export default function Sidebar() {
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
const selectedModel = useAtomValue(selectedModelAtom)
|
||||
const { updateThreadMetadata } = useCreateNewThread()
|
||||
const threadStates = useAtomValue(threadStatesAtom)
|
||||
|
||||
const onReviewInFinderClick = async (type: string) => {
|
||||
if (!activeThread) return
|
||||
if (!activeThread.isFinishInit) {
|
||||
const activeThreadState = threadStates[activeThread.id]
|
||||
if (!activeThreadState.isFinishInit) {
|
||||
alert('Thread is not started yet')
|
||||
return
|
||||
}
|
||||
@ -60,7 +64,8 @@ export default function Sidebar() {
|
||||
|
||||
const onViewJsonClick = async (type: string) => {
|
||||
if (!activeThread) return
|
||||
if (!activeThread.isFinishInit) {
|
||||
const activeThreadState = threadStates[activeThread.id]
|
||||
if (!activeThreadState.isFinishInit) {
|
||||
alert('Thread is not started yet')
|
||||
return
|
||||
}
|
||||
@ -189,6 +194,9 @@ export default function Sidebar() {
|
||||
>
|
||||
<div className="p-2">
|
||||
<DropdownListSidebar />
|
||||
<div className="mt-4">
|
||||
<ModelSetting />
|
||||
</div>
|
||||
</div>
|
||||
</CardSidebar>
|
||||
</div>
|
||||
|
||||
@ -12,7 +12,7 @@ import {
|
||||
import { twMerge } from 'tailwind-merge'
|
||||
|
||||
import { useCreateNewThread } from '@/hooks/useCreateNewThread'
|
||||
import useDeleteThread from '@/hooks/useDeleteConversation'
|
||||
import useDeleteThread from '@/hooks/useDeleteThread'
|
||||
import useGetAllThreads from '@/hooks/useGetAllThreads'
|
||||
|
||||
import useGetAssistants from '@/hooks/useGetAssistants'
|
||||
@ -25,7 +25,7 @@ import {
|
||||
activeThreadAtom,
|
||||
threadStatesAtom,
|
||||
threadsAtom,
|
||||
} from '@/helpers/atoms/Conversation.atom'
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
export default function ThreadList() {
|
||||
const threads = useAtomValue(threadsAtom)
|
||||
|
||||
@ -29,9 +29,9 @@ import {
|
||||
activeThreadAtom,
|
||||
getActiveThreadIdAtom,
|
||||
waitingToSendMessage,
|
||||
} from '@/helpers/atoms/Conversation.atom'
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
import { activeThreadStateAtom } from '@/helpers/atoms/Conversation.atom'
|
||||
import { activeThreadStateAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
const ChatScreen = () => {
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user