feat: reroute threads and messages requests to the backend

This commit is contained in:
Louis 2024-12-05 17:33:43 +07:00
parent 14737b7e31
commit 174f1c7dcb
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
36 changed files with 774 additions and 1120 deletions

View File

@ -1,4 +1,10 @@
import { Thread, ThreadInterface, ThreadMessage, MessageInterface } from '../../types' import {
Thread,
ThreadInterface,
ThreadMessage,
MessageInterface,
ThreadAssistantInfo,
} from '../../types'
import { BaseExtension, ExtensionTypeEnum } from '../extension' import { BaseExtension, ExtensionTypeEnum } from '../extension'
/** /**
@ -17,10 +23,21 @@ export abstract class ConversationalExtension
return ExtensionTypeEnum.Conversational return ExtensionTypeEnum.Conversational
} }
abstract getThreads(): Promise<Thread[]> abstract listThreads(): Promise<Thread[]>
abstract saveThread(thread: Thread): Promise<void> abstract createThread(thread: Partial<Thread>): Promise<Thread>
abstract modifyThread(thread: Thread): Promise<void>
abstract deleteThread(threadId: string): Promise<void> abstract deleteThread(threadId: string): Promise<void>
abstract addNewMessage(message: ThreadMessage): Promise<void> abstract createMessage(message: Partial<ThreadMessage>): Promise<ThreadMessage>
abstract writeMessages(threadId: string, messages: ThreadMessage[]): Promise<void> abstract deleteMessage(threadId: string, messageId: string): Promise<void>
abstract getAllMessages(threadId: string): Promise<ThreadMessage[]> abstract listMessages(threadId: string): Promise<ThreadMessage[]>
abstract getThreadAssistant(threadId: string): Promise<ThreadAssistantInfo>
abstract createThreadAssistant(
threadId: string,
assistant: ThreadAssistantInfo
): Promise<ThreadAssistantInfo>
abstract modifyThreadAssistant(
threadId: string,
assistant: ThreadAssistantInfo
): Promise<ThreadAssistantInfo>
abstract modifyMessage(message: ThreadMessage): Promise<ThreadMessage>
} }

View File

@ -2,7 +2,6 @@ import { events } from '../../events'
import { BaseExtension } from '../../extension' import { BaseExtension } from '../../extension'
import { MessageRequest, Model, ModelEvent } from '../../../types' import { MessageRequest, Model, ModelEvent } from '../../../types'
import { EngineManager } from './EngineManager' import { EngineManager } from './EngineManager'
import { ModelManager } from '../../models/manager'
/** /**
* Base AIEngine * Base AIEngine

View File

@ -6,7 +6,6 @@ import {
mkdirSync, mkdirSync,
appendFileSync, appendFileSync,
createWriteStream, createWriteStream,
rmdirSync,
} from 'fs' } from 'fs'
import { JanApiRouteConfiguration, RouteConfiguration } from './configuration' import { JanApiRouteConfiguration, RouteConfiguration } from './configuration'
import { join } from 'path' import { join } from 'path'
@ -126,7 +125,7 @@ export const createThread = async (thread: any) => {
} }
} }
const threadId = generateThreadId(thread.assistants[0].assistant_id) const threadId = generateThreadId(thread.assistants[0]?.assistant_id)
try { try {
const updatedThread = { const updatedThread = {
...thread, ...thread,
@ -280,7 +279,7 @@ export const models = async (request: any, reply: any) => {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
} }
const response = await fetch(`${CORTEX_API_URL}/models${request.url.split('/models')[1] ?? ""}`, { const response = await fetch(`${CORTEX_API_URL}/models${request.url.split('/models')[1] ?? ''}`, {
method: request.method, method: request.method,
headers: headers, headers: headers,
body: JSON.stringify(request.body), body: JSON.stringify(request.body),

View File

@ -11,20 +11,20 @@ export interface MessageInterface {
* @param {ThreadMessage} message - The message to be added. * @param {ThreadMessage} message - The message to be added.
* @returns {Promise<void>} A promise that resolves when the message has been added. * @returns {Promise<void>} A promise that resolves when the message has been added.
*/ */
addNewMessage(message: ThreadMessage): Promise<void> createMessage(message: ThreadMessage): Promise<ThreadMessage>
/**
* Writes an array of messages to a specific thread.
* @param {string} threadId - The ID of the thread to write the messages to.
* @param {ThreadMessage[]} messages - The array of messages to be written.
* @returns {Promise<void>} A promise that resolves when the messages have been written.
*/
writeMessages(threadId: string, messages: ThreadMessage[]): Promise<void>
/** /**
* Retrieves all messages from a specific thread. * Retrieves all messages from a specific thread.
* @param {string} threadId - The ID of the thread to retrieve the messages from. * @param {string} threadId - The ID of the thread to retrieve the messages from.
* @returns {Promise<ThreadMessage[]>} A promise that resolves to an array of messages from the thread. * @returns {Promise<ThreadMessage[]>} A promise that resolves to an array of messages from the thread.
*/ */
getAllMessages(threadId: string): Promise<ThreadMessage[]> listMessages(threadId: string): Promise<ThreadMessage[]>
/**
* Deletes a specific message from a thread.
* @param {string} threadId - The ID of the thread from which the message will be deleted.
* @param {string} messageId - The ID of the message to be deleted.
* @returns {Promise<void>} A promise that resolves when the message has been successfully deleted.
*/
deleteMessage(threadId: string, messageId: string): Promise<void>
} }

View File

@ -11,15 +11,23 @@ export interface ThreadInterface {
* @abstract * @abstract
* @returns {Promise<Thread[]>} A promise that resolves to an array of threads. * @returns {Promise<Thread[]>} A promise that resolves to an array of threads.
*/ */
getThreads(): Promise<Thread[]> listThreads(): Promise<Thread[]>
/** /**
* Saves a thread. * Create a thread.
* @abstract * @abstract
* @param {Thread} thread - The thread to save. * @param {Thread} thread - The thread to save.
* @returns {Promise<void>} A promise that resolves when the thread is saved. * @returns {Promise<void>} A promise that resolves when the thread is saved.
*/ */
saveThread(thread: Thread): Promise<void> createThread(thread: Thread): Promise<Thread>
/**
* modify a thread.
* @abstract
* @param {Thread} thread - The thread to save.
* @returns {Promise<void>} A promise that resolves when the thread is saved.
*/
modifyThread(thread: Thread): Promise<void>
/** /**
* Deletes a thread. * Deletes a thread.

View File

@ -18,12 +18,14 @@
"devDependencies": { "devDependencies": {
"cpx": "^1.5.0", "cpx": "^1.5.0",
"rimraf": "^3.0.2", "rimraf": "^3.0.2",
"ts-loader": "^9.5.0",
"webpack": "^5.88.2", "webpack": "^5.88.2",
"webpack-cli": "^5.1.4", "webpack-cli": "^5.1.4"
"ts-loader": "^9.5.0"
}, },
"dependencies": { "dependencies": {
"@janhq/core": "file:../../core" "@janhq/core": "file:../../core",
"ky": "^1.7.2",
"p-queue": "^8.0.1"
}, },
"engines": { "engines": {
"node": ">=18.0.0" "node": ">=18.0.0"

View File

@ -0,0 +1,14 @@
export {}
declare global {
declare const API_URL: string
declare const SOCKET_URL: string
interface Core {
api: APIFunctions
events: EventEmitter
}
interface Window {
core?: Core | undefined
electronAPI?: any | undefined
}
}

View File

@ -1,408 +0,0 @@
/**
* @jest-environment jsdom
*/
jest.mock('@janhq/core', () => ({
...jest.requireActual('@janhq/core/node'),
fs: {
existsSync: jest.fn(),
mkdir: jest.fn(),
writeFileSync: jest.fn(),
readdirSync: jest.fn(),
readFileSync: jest.fn(),
appendFileSync: jest.fn(),
rm: jest.fn(),
writeBlob: jest.fn(),
joinPath: jest.fn(),
fileStat: jest.fn(),
},
joinPath: jest.fn(),
ConversationalExtension: jest.fn(),
}))
import { fs } from '@janhq/core'
import JSONConversationalExtension from '.'
describe('JSONConversationalExtension Tests', () => {
let extension: JSONConversationalExtension
beforeEach(() => {
// @ts-ignore
extension = new JSONConversationalExtension()
})
it('should create thread folder on load if it does not exist', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
await extension.onLoad()
expect(mkdirSpy).toHaveBeenCalledWith('file://threads')
})
it('should log message on unload', () => {
const consoleSpy = jest.spyOn(console, 'debug').mockImplementation()
extension.onUnload()
expect(consoleSpy).toHaveBeenCalledWith(
'JSONConversationalExtension unloaded'
)
})
it('should return sorted threads', async () => {
jest
.spyOn(extension, 'getValidThreadDirs')
.mockResolvedValue(['dir1', 'dir2'])
jest
.spyOn(extension, 'readThread')
.mockResolvedValueOnce({ updated: '2023-01-01' })
.mockResolvedValueOnce({ updated: '2023-01-02' })
const threads = await extension.getThreads()
expect(threads).toEqual([
{ updated: '2023-01-02' },
{ updated: '2023-01-01' },
])
})
it('should ignore broken threads', async () => {
jest
.spyOn(extension, 'getValidThreadDirs')
.mockResolvedValue(['dir1', 'dir2'])
jest
.spyOn(extension, 'readThread')
.mockResolvedValueOnce(JSON.stringify({ updated: '2023-01-01' }))
.mockResolvedValueOnce('this_is_an_invalid_json_content')
const threads = await extension.getThreads()
expect(threads).toEqual([{ updated: '2023-01-01' }])
})
it('should save thread', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const writeFileSyncSpy = jest
.spyOn(fs, 'writeFileSync')
.mockResolvedValue({})
const thread = { id: '1', updated: '2023-01-01' } as any
await extension.saveThread(thread)
expect(mkdirSpy).toHaveBeenCalled()
expect(writeFileSyncSpy).toHaveBeenCalled()
})
it('should delete thread', async () => {
const rmSpy = jest.spyOn(fs, 'rm').mockResolvedValue({})
await extension.deleteThread('1')
expect(rmSpy).toHaveBeenCalled()
})
it('should add new message', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const appendFileSyncSpy = jest
.spyOn(fs, 'appendFileSync')
.mockResolvedValue({})
const message = {
thread_id: '1',
content: [{ type: 'text', text: { annotations: [] } }],
} as any
await extension.addNewMessage(message)
expect(mkdirSpy).toHaveBeenCalled()
expect(appendFileSyncSpy).toHaveBeenCalled()
})
it('should store image', async () => {
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
await extension.storeImage(
'',
'path/to/image.png'
)
expect(writeBlobSpy).toHaveBeenCalled()
})
it('should store file', async () => {
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
await extension.storeFile(
'data:application/pdf;base64,abcd',
'path/to/file.pdf'
)
expect(writeBlobSpy).toHaveBeenCalled()
})
it('should write messages', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const writeFileSyncSpy = jest
.spyOn(fs, 'writeFileSync')
.mockResolvedValue({})
const messages = [{ id: '1', thread_id: '1', content: [] }] as any
await extension.writeMessages('1', messages)
expect(mkdirSpy).toHaveBeenCalled()
expect(writeFileSyncSpy).toHaveBeenCalled()
})
it('should get all messages on string response', async () => {
jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl'])
jest.spyOn(fs, 'readFileSync').mockResolvedValue('{"id":"1"}\n{"id":"2"}\n')
const messages = await extension.getAllMessages('1')
expect(messages).toEqual([{ id: '1' }, { id: '2' }])
})
it('should get all messages on object response', async () => {
jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl'])
jest.spyOn(fs, 'readFileSync').mockResolvedValue({ id: 1 })
const messages = await extension.getAllMessages('1')
expect(messages).toEqual([{ id: 1 }])
})
it('get all messages return empty on error', async () => {
jest.spyOn(fs, 'readdirSync').mockRejectedValue(['messages.jsonl'])
const messages = await extension.getAllMessages('1')
expect(messages).toEqual([])
})
it('return empty messages on no messages file', async () => {
jest.spyOn(fs, 'readdirSync').mockResolvedValue([])
const messages = await extension.getAllMessages('1')
expect(messages).toEqual([])
})
it('should ignore error message', async () => {
jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl'])
jest
.spyOn(fs, 'readFileSync')
.mockResolvedValue('{"id":"1"}\nyolo\n{"id":"2"}\n')
const messages = await extension.getAllMessages('1')
expect(messages).toEqual([{ id: '1' }, { id: '2' }])
})
it('should create thread folder on load if it does not exist', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
await extension.onLoad()
expect(mkdirSpy).toHaveBeenCalledWith('file://threads')
})
it('should log message on unload', () => {
const consoleSpy = jest.spyOn(console, 'debug').mockImplementation()
extension.onUnload()
expect(consoleSpy).toHaveBeenCalledWith(
'JSONConversationalExtension unloaded'
)
})
it('should return sorted threads', async () => {
jest
.spyOn(extension, 'getValidThreadDirs')
.mockResolvedValue(['dir1', 'dir2'])
jest
.spyOn(extension, 'readThread')
.mockResolvedValueOnce({ updated: '2023-01-01' })
.mockResolvedValueOnce({ updated: '2023-01-02' })
const threads = await extension.getThreads()
expect(threads).toEqual([
{ updated: '2023-01-02' },
{ updated: '2023-01-01' },
])
})
it('should ignore broken threads', async () => {
jest
.spyOn(extension, 'getValidThreadDirs')
.mockResolvedValue(['dir1', 'dir2'])
jest
.spyOn(extension, 'readThread')
.mockResolvedValueOnce(JSON.stringify({ updated: '2023-01-01' }))
.mockResolvedValueOnce('this_is_an_invalid_json_content')
const threads = await extension.getThreads()
expect(threads).toEqual([{ updated: '2023-01-01' }])
})
it('should save thread', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const writeFileSyncSpy = jest
.spyOn(fs, 'writeFileSync')
.mockResolvedValue({})
const thread = { id: '1', updated: '2023-01-01' } as any
await extension.saveThread(thread)
expect(mkdirSpy).toHaveBeenCalled()
expect(writeFileSyncSpy).toHaveBeenCalled()
})
it('should delete thread', async () => {
const rmSpy = jest.spyOn(fs, 'rm').mockResolvedValue({})
await extension.deleteThread('1')
expect(rmSpy).toHaveBeenCalled()
})
it('should add new message', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const appendFileSyncSpy = jest
.spyOn(fs, 'appendFileSync')
.mockResolvedValue({})
const message = {
thread_id: '1',
content: [{ type: 'text', text: { annotations: [] } }],
} as any
await extension.addNewMessage(message)
expect(mkdirSpy).toHaveBeenCalled()
expect(appendFileSyncSpy).toHaveBeenCalled()
})
it('should add new image message', async () => {
jest
.spyOn(fs, 'existsSync')
// @ts-ignore
.mockResolvedValueOnce(false)
// @ts-ignore
.mockResolvedValueOnce(false)
// @ts-ignore
.mockResolvedValueOnce(true)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const appendFileSyncSpy = jest
.spyOn(fs, 'appendFileSync')
.mockResolvedValue({})
jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
const message = {
thread_id: '1',
content: [
{ type: 'image', text: { annotations: ['data:image;base64,hehe'] } },
],
} as any
await extension.addNewMessage(message)
expect(mkdirSpy).toHaveBeenCalled()
expect(appendFileSyncSpy).toHaveBeenCalled()
})
it('should add new pdf message', async () => {
jest
.spyOn(fs, 'existsSync')
// @ts-ignore
.mockResolvedValueOnce(false)
// @ts-ignore
.mockResolvedValueOnce(false)
// @ts-ignore
.mockResolvedValueOnce(true)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const appendFileSyncSpy = jest
.spyOn(fs, 'appendFileSync')
.mockResolvedValue({})
jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
const message = {
thread_id: '1',
content: [
{ type: 'pdf', text: { annotations: ['data:pdf;base64,hehe'] } },
],
} as any
await extension.addNewMessage(message)
expect(mkdirSpy).toHaveBeenCalled()
expect(appendFileSyncSpy).toHaveBeenCalled()
})
it('should store image', async () => {
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
await extension.storeImage(
'',
'path/to/image.png'
)
expect(writeBlobSpy).toHaveBeenCalled()
})
it('should store file', async () => {
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
await extension.storeFile(
'data:application/pdf;base64,abcd',
'path/to/file.pdf'
)
expect(writeBlobSpy).toHaveBeenCalled()
})
})
describe('test readThread', () => {
let extension: JSONConversationalExtension
beforeEach(() => {
// @ts-ignore
extension = new JSONConversationalExtension()
})
it('should read thread', async () => {
jest
.spyOn(fs, 'readFileSync')
.mockResolvedValue(JSON.stringify({ id: '1' }))
const thread = await extension.readThread('1')
expect(thread).toEqual(`{"id":"1"}`)
})
it('getValidThreadDirs should return valid thread directories', async () => {
jest
.spyOn(fs, 'readdirSync')
.mockResolvedValueOnce(['1', '2', '3'])
.mockResolvedValueOnce(['thread.json'])
.mockResolvedValueOnce(['thread.json'])
.mockResolvedValueOnce([])
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(true)
jest.spyOn(fs, 'fileStat').mockResolvedValue({
isDirectory: true,
} as any)
const validThreadDirs = await extension.getValidThreadDirs()
expect(validThreadDirs).toEqual(['1', '2'])
})
})

View File

@ -1,90 +1,71 @@
import { import {
fs,
joinPath,
ConversationalExtension, ConversationalExtension,
Thread, Thread,
ThreadAssistantInfo,
ThreadMessage, ThreadMessage,
} from '@janhq/core' } from '@janhq/core'
import { safelyParseJSON } from './jsonUtil' import ky from 'ky'
import PQueue from 'p-queue'
type ThreadList = {
data: Thread[]
}
type MessageList = {
data: ThreadMessage[]
}
/** /**
* JSONConversationalExtension is a ConversationalExtension implementation that provides * JSONConversationalExtension is a ConversationalExtension implementation that provides
* functionality for managing threads. * functionality for managing threads.
*/ */
export default class JSONConversationalExtension extends ConversationalExtension { export default class JSONConversationalExtension extends ConversationalExtension {
private static readonly _threadFolder = 'file://threads' queue = new PQueue({ concurrency: 1 })
private static readonly _threadInfoFileName = 'thread.json'
private static readonly _threadMessagesFileName = 'messages.jsonl'
/** /**
* Called when the extension is loaded. * Called when the extension is loaded.
*/ */
async onLoad() { async onLoad() {
if (!(await fs.existsSync(JSONConversationalExtension._threadFolder))) { this.queue.add(() => this.healthz())
await fs.mkdir(JSONConversationalExtension._threadFolder)
}
} }
/** /**
* Called when the extension is unloaded. * Called when the extension is unloaded.
*/ */
onUnload() { onUnload() {}
console.debug('JSONConversationalExtension unloaded')
}
/** /**
* Returns a Promise that resolves to an array of Conversation objects. * Returns a Promise that resolves to an array of Conversation objects.
*/ */
async getThreads(): Promise<Thread[]> { async listThreads(): Promise<Thread[]> {
try { return this.queue.add(() =>
const threadDirs = await this.getValidThreadDirs() ky
.get(`${API_URL}/v1/threads`)
const promises = threadDirs.map((dirName) => this.readThread(dirName)) .json<ThreadList>()
const promiseResults = await Promise.allSettled(promises) .then((e) => e.data)
const convos = promiseResults ) as Promise<Thread[]>
.map((result) => {
if (result.status === 'fulfilled') {
return typeof result.value === 'object'
? result.value
: safelyParseJSON(result.value)
}
return undefined
})
.filter((convo) => !!convo)
convos.sort(
(a, b) => new Date(b.updated).getTime() - new Date(a.updated).getTime()
)
return convos
} catch (error) {
console.error(error)
return []
}
} }
/** /**
* Saves a Thread object to a json file. * Saves a Thread object to a json file.
* @param thread The Thread object to save. * @param thread The Thread object to save.
*/ */
async saveThread(thread: Thread): Promise<void> { async createThread(thread: Thread): Promise<Thread> {
try { return this.queue.add(() =>
const threadDirPath = await joinPath([ ky.post(`${API_URL}/v1/threads`, { json: thread }).json<Thread>()
JSONConversationalExtension._threadFolder, ) as Promise<Thread>
thread.id,
])
const threadJsonPath = await joinPath([
threadDirPath,
JSONConversationalExtension._threadInfoFileName,
])
if (!(await fs.existsSync(threadDirPath))) {
await fs.mkdir(threadDirPath)
} }
await fs.writeFileSync(threadJsonPath, JSON.stringify(thread, null, 2)) /**
} catch (err) { * Saves a Thread object to a json file.
console.error(err) * @param thread The Thread object to save.
Promise.reject(err) */
} async modifyThread(thread: Thread): Promise<void> {
return this.queue
.add(() =>
ky.post(`${API_URL}/v1/threads/${thread.id}`, { json: thread })
)
.then()
} }
/** /**
@ -92,189 +73,126 @@ export default class JSONConversationalExtension extends ConversationalExtension
* @param threadId The ID of the thread to delete. * @param threadId The ID of the thread to delete.
*/ */
async deleteThread(threadId: string): Promise<void> { async deleteThread(threadId: string): Promise<void> {
const path = await joinPath([ return this.queue
JSONConversationalExtension._threadFolder, .add(() => ky.delete(`${API_URL}/v1/threads/${threadId}`))
`${threadId}`, .then()
])
try {
await fs.rm(path)
} catch (err) {
console.error(err)
}
}
async addNewMessage(message: ThreadMessage): Promise<void> {
try {
const threadDirPath = await joinPath([
JSONConversationalExtension._threadFolder,
message.thread_id,
])
const threadMessagePath = await joinPath([
threadDirPath,
JSONConversationalExtension._threadMessagesFileName,
])
if (!(await fs.existsSync(threadDirPath))) await fs.mkdir(threadDirPath)
if (message.content[0]?.type === 'image') {
const filesPath = await joinPath([threadDirPath, 'files'])
if (!(await fs.existsSync(filesPath))) await fs.mkdir(filesPath)
const imagePath = await joinPath([filesPath, `${message.id}.png`])
const base64 = message.content[0].text.annotations[0]
await this.storeImage(base64, imagePath)
if ((await fs.existsSync(imagePath)) && message.content?.length) {
// Use file path instead of blob
message.content[0].text.annotations[0] = `threads/${message.thread_id}/files/${message.id}.png`
}
}
if (message.content[0]?.type === 'pdf') {
const filesPath = await joinPath([threadDirPath, 'files'])
if (!(await fs.existsSync(filesPath))) await fs.mkdir(filesPath)
const filePath = await joinPath([filesPath, `${message.id}.pdf`])
const blob = message.content[0].text.annotations[0]
await this.storeFile(blob, filePath)
if ((await fs.existsSync(filePath)) && message.content?.length) {
// Use file path instead of blob
message.content[0].text.annotations[0] = `threads/${message.thread_id}/files/${message.id}.pdf`
}
}
await fs.appendFileSync(threadMessagePath, JSON.stringify(message) + '\n')
Promise.resolve()
} catch (err) {
Promise.reject(err)
}
}
async storeImage(base64: string, filePath: string): Promise<void> {
const base64Data = base64.replace(/^data:image\/\w+;base64,/, '')
try {
await fs.writeBlob(filePath, base64Data)
} catch (err) {
console.error(err)
}
}
async storeFile(base64: string, filePath: string): Promise<void> {
const base64Data = base64.replace(/^data:application\/pdf;base64,/, '')
try {
await fs.writeBlob(filePath, base64Data)
} catch (err) {
console.error(err)
}
}
async writeMessages(
threadId: string,
messages: ThreadMessage[]
): Promise<void> {
try {
const threadDirPath = await joinPath([
JSONConversationalExtension._threadFolder,
threadId,
])
const threadMessagePath = await joinPath([
threadDirPath,
JSONConversationalExtension._threadMessagesFileName,
])
if (!(await fs.existsSync(threadDirPath))) await fs.mkdir(threadDirPath)
await fs.writeFileSync(
threadMessagePath,
messages.map((msg) => JSON.stringify(msg)).join('\n') +
(messages.length ? '\n' : '')
)
Promise.resolve()
} catch (err) {
Promise.reject(err)
}
} }
/** /**
* A promise builder for reading a thread from a file. * Adds a new message to a specified thread.
* @param threadDirName the thread dir we are reading from. * @param message The ThreadMessage object to be added.
* @returns data of the thread * @returns A Promise that resolves when the message has been added.
*/ */
async readThread(threadDirName: string): Promise<any> { async createMessage(message: ThreadMessage): Promise<ThreadMessage> {
return fs.readFileSync( return this.queue.add(() =>
await joinPath([ ky
JSONConversationalExtension._threadFolder, .post(`${API_URL}/v1/threads/${message.thread_id}/messages`, {
threadDirName, json: message,
JSONConversationalExtension._threadInfoFileName,
]),
'utf-8'
)
}
/**
* Returns a Promise that resolves to an array of thread directories.
* @private
*/
async getValidThreadDirs(): Promise<string[]> {
const fileInsideThread: string[] = await fs.readdirSync(
JSONConversationalExtension._threadFolder
)
const threadDirs: string[] = []
for (let i = 0; i < fileInsideThread.length; i++) {
const path = await joinPath([
JSONConversationalExtension._threadFolder,
fileInsideThread[i],
])
if (!(await fs.fileStat(path))?.isDirectory) continue
const isHavingThreadInfo = (await fs.readdirSync(path)).includes(
JSONConversationalExtension._threadInfoFileName
)
if (!isHavingThreadInfo) {
console.debug(`Ignore ${path} because it does not have thread info`)
continue
}
threadDirs.push(fileInsideThread[i])
}
return threadDirs
}
async getAllMessages(threadId: string): Promise<ThreadMessage[]> {
try {
const threadDirPath = await joinPath([
JSONConversationalExtension._threadFolder,
threadId,
])
const files: string[] = await fs.readdirSync(threadDirPath)
if (
!files.includes(JSONConversationalExtension._threadMessagesFileName)
) {
console.debug(`${threadDirPath} not contains message file`)
return []
}
const messageFilePath = await joinPath([
threadDirPath,
JSONConversationalExtension._threadMessagesFileName,
])
let readResult = await fs.readFileSync(messageFilePath, 'utf-8')
if (typeof readResult === 'object') {
readResult = JSON.stringify(readResult)
}
const result = readResult.split('\n').filter((line) => line !== '')
const messages: ThreadMessage[] = []
result.forEach((line: string) => {
const message = safelyParseJSON(line)
if (message) messages.push(safelyParseJSON(line))
}) })
return messages .json<ThreadMessage>()
} catch (err) { ) as Promise<ThreadMessage>
console.error(err)
return []
} }
/**
* Modifies a message in a thread.
* @param message
* @returns
*/
async modifyMessage(message: ThreadMessage): Promise<ThreadMessage> {
return this.queue.add(() =>
ky
.post(
`${API_URL}/v1/threads/${message.thread_id}/messages/${message.id}`,
{
json: message,
}
)
.json<ThreadMessage>()
) as Promise<ThreadMessage>
}
/**
* Deletes a specific message from a thread.
* @param threadId The ID of the thread containing the message.
* @param messageId The ID of the message to be deleted.
* @returns A Promise that resolves when the message has been successfully deleted.
*/
async deleteMessage(threadId: string, messageId: string): Promise<void> {
return this.queue
.add(() =>
ky.delete(`${API_URL}/v1/threads/${threadId}/messages/${messageId}`)
)
.then()
}
/**
* Retrieves all messages for a specified thread.
* @param threadId The ID of the thread to get messages from.
* @returns A Promise that resolves to an array of ThreadMessage objects.
*/
async listMessages(threadId: string): Promise<ThreadMessage[]> {
return this.queue.add(() =>
ky
.get(`${API_URL}/v1/threads/${threadId}/messages?order=asc`)
.json<MessageList>()
.then((e) => e.data)
) as Promise<ThreadMessage[]>
}
/**
* Retrieves the assistant information for a specified thread.
* @param threadId The ID of the thread for which to retrieve assistant information.
* @returns A Promise that resolves to a ThreadAssistantInfo object containing
* the details of the assistant associated with the specified thread.
*/
async getThreadAssistant(threadId: string): Promise<ThreadAssistantInfo> {
return this.queue.add(() =>
ky.get(`${API_URL}/v1/assistants/${threadId}`).json<ThreadAssistantInfo>()
) as Promise<ThreadAssistantInfo>
}
/**
* Creates a new assistant for the specified thread.
* @param threadId The ID of the thread for which the assistant is being created.
* @param assistant The information about the assistant to be created.
* @returns A Promise that resolves to the newly created ThreadAssistantInfo object.
*/
async createThreadAssistant(
threadId: string,
assistant: ThreadAssistantInfo
): Promise<ThreadAssistantInfo> {
return this.queue.add(() =>
ky
.post(`${API_URL}/v1/assistants/${threadId}`, { json: assistant })
.json<ThreadAssistantInfo>()
) as Promise<ThreadAssistantInfo>
}
/**
* Modifies an existing assistant for the specified thread.
* @param threadId The ID of the thread for which the assistant is being modified.
* @param assistant The updated information for the assistant.
* @returns A Promise that resolves to the updated ThreadAssistantInfo object.
*/
async modifyThreadAssistant(
threadId: string,
assistant: ThreadAssistantInfo
): Promise<ThreadAssistantInfo> {
return this.queue.add(() =>
ky
.patch(`${API_URL}/v1/assistants/${threadId}`, { json: assistant })
.json<ThreadAssistantInfo>()
) as Promise<ThreadAssistantInfo>
}
/**
* Do health check on cortex.cpp
* @returns
*/
healthz(): Promise<void> {
return ky
.get(`${API_URL}/healthz`, {
retry: { limit: 20, delay: () => 500, methods: ['get'] },
})
.then(() => {})
} }
} }

View File

@ -1,14 +0,0 @@
// Note about performance
// The v8 JavaScript engine used by Node.js cannot optimise functions which contain a try/catch block.
// v8 4.5 and above can optimise try/catch
export function safelyParseJSON(json) {
// This function cannot be optimised, it's best to
// keep it small!
var parsed
try {
parsed = JSON.parse(json)
} catch (e) {
return undefined
}
return parsed // Could be undefined!
}

View File

@ -17,7 +17,12 @@ module.exports = {
filename: 'index.js', // Adjust the output file name as needed filename: 'index.js', // Adjust the output file name as needed
library: { type: 'module' }, // Specify ESM output format library: { type: 'module' }, // Specify ESM output format
}, },
plugins: [new webpack.DefinePlugin({})], plugins: [
new webpack.DefinePlugin({
API_URL: JSON.stringify('http://127.0.0.1:39291'),
SOCKET_URL: JSON.stringify('ws://127.0.0.1:39291'),
}),
],
resolve: { resolve: {
extensions: ['.ts', '.js'], extensions: ['.ts', '.js'],
}, },

View File

@ -18,14 +18,14 @@ import { isLocalEngine } from '@/utils/modelEngine'
import { mainViewStateAtom } from '@/helpers/atoms/App.atom' import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom' import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
const ErrorMessage = ({ message }: { message: ThreadMessage }) => { const ErrorMessage = ({ message }: { message: ThreadMessage }) => {
const setModalTroubleShooting = useSetAtom(modalTroubleShootingAtom) const setModalTroubleShooting = useSetAtom(modalTroubleShootingAtom)
const setMainState = useSetAtom(mainViewStateAtom) const setMainState = useSetAtom(mainViewStateAtom)
const setSelectedSettingScreen = useSetAtom(selectedSettingAtom) const setSelectedSettingScreen = useSetAtom(selectedSettingAtom)
const activeThread = useAtomValue(activeThreadAtom) const activeAssistant = useAtomValue(activeAssistantAtom)
const defaultDesc = () => { const defaultDesc = () => {
return ( return (

View File

@ -46,6 +46,7 @@ import {
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom' import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom'
import { import {
configuredModelsAtom, configuredModelsAtom,
@ -75,6 +76,7 @@ const ModelDropdown = ({
const [searchText, setSearchText] = useState('') const [searchText, setSearchText] = useState('')
const [open, setOpen] = useState(false) const [open, setOpen] = useState(false)
const activeThread = useAtomValue(activeThreadAtom) const activeThread = useAtomValue(activeThreadAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const downloadingModels = useAtomValue(getDownloadingModelAtom) const downloadingModels = useAtomValue(getDownloadingModelAtom)
const [toggle, setToggle] = useState<HTMLDivElement | null>(null) const [toggle, setToggle] = useState<HTMLDivElement | null>(null)
const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom) const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
@ -151,17 +153,24 @@ const ModelDropdown = ({
useEffect(() => { useEffect(() => {
if (!activeThread) return if (!activeThread) return
const modelId = activeThread?.assistants?.[0]?.model?.id const modelId = activeAssistant?.model?.id
let model = downloadedModels.find((model) => model.id === modelId) let model = downloadedModels.find((model) => model.id === modelId)
if (!model) { if (!model) {
model = recommendedModel model = recommendedModel
} }
setSelectedModel(model) setSelectedModel(model)
}, [recommendedModel, activeThread, downloadedModels, setSelectedModel]) }, [
recommendedModel,
activeThread,
downloadedModels,
setSelectedModel,
activeAssistant?.model?.id,
])
const onClickModelItem = useCallback( const onClickModelItem = useCallback(
async (modelId: string) => { async (modelId: string) => {
if (!activeAssistant) return
const model = downloadedModels.find((m) => m.id === modelId) const model = downloadedModels.find((m) => m.id === modelId)
setSelectedModel(model) setSelectedModel(model)
setOpen(false) setOpen(false)
@ -172,14 +181,14 @@ const ModelDropdown = ({
...activeThread, ...activeThread,
assistants: [ assistants: [
{ {
...activeThread.assistants[0], ...activeAssistant,
tools: [ tools: [
{ {
type: 'retrieval', type: 'retrieval',
enabled: isModelSupportRagAndTools(model as Model), enabled: isModelSupportRagAndTools(model as Model),
settings: { settings: {
...(activeThread.assistants[0].tools && ...(activeAssistant.tools &&
activeThread.assistants[0].tools[0]?.settings), activeAssistant.tools[0]?.settings),
}, },
}, },
], ],
@ -215,13 +224,14 @@ const ModelDropdown = ({
} }
}, },
[ [
activeAssistant,
downloadedModels, downloadedModels,
activeThread,
setSelectedModel, setSelectedModel,
activeThread,
updateThreadMetadata,
isModelSupportRagAndTools, isModelSupportRagAndTools,
setThreadModelParams, setThreadModelParams,
updateModelParameter, updateModelParameter,
updateThreadMetadata,
] ]
) )

View File

@ -1,4 +1,4 @@
import { Fragment, useCallback, useEffect, useRef } from 'react' import { Fragment, use, useCallback, useEffect, useRef } from 'react'
import { import {
ChatCompletionMessage, ChatCompletionMessage,
@ -31,6 +31,7 @@ import {
addNewMessageAtom, addNewMessageAtom,
updateMessageAtom, updateMessageAtom,
tokenSpeedAtom, tokenSpeedAtom,
deleteMessageAtom,
} from '@/helpers/atoms/ChatMessage.atom' } from '@/helpers/atoms/ChatMessage.atom'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import { import {
@ -49,6 +50,7 @@ export default function ModelHandler() {
const addNewMessage = useSetAtom(addNewMessageAtom) const addNewMessage = useSetAtom(addNewMessageAtom)
const updateMessage = useSetAtom(updateMessageAtom) const updateMessage = useSetAtom(updateMessageAtom)
const downloadedModels = useAtomValue(downloadedModelsAtom) const downloadedModels = useAtomValue(downloadedModelsAtom)
const deleteMessage = useSetAtom(deleteMessageAtom)
const activeModel = useAtomValue(activeModelAtom) const activeModel = useAtomValue(activeModelAtom)
const setActiveModel = useSetAtom(activeModelAtom) const setActiveModel = useSetAtom(activeModelAtom)
const setStateModel = useSetAtom(stateModelAtom) const setStateModel = useSetAtom(stateModelAtom)
@ -86,7 +88,7 @@ export default function ModelHandler() {
}, [activeModelParams]) }, [activeModelParams])
const onNewMessageResponse = useCallback( const onNewMessageResponse = useCallback(
(message: ThreadMessage) => { async (message: ThreadMessage) => {
if (message.type === MessageRequestType.Thread) { if (message.type === MessageRequestType.Thread) {
addNewMessage(message) addNewMessage(message)
} }
@ -154,12 +156,15 @@ export default function ModelHandler() {
...thread, ...thread,
title: cleanedMessageContent, title: cleanedMessageContent,
metadata: thread.metadata, metadata: {
...thread.metadata,
title: cleanedMessageContent,
},
} }
extensionManager extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.saveThread({ ?.modifyThread({
...updatedThread, ...updatedThread,
}) })
.then(() => { .then(() => {
@ -233,7 +238,9 @@ export default function ModelHandler() {
const thread = threadsRef.current?.find((e) => e.id == message.thread_id) const thread = threadsRef.current?.find((e) => e.id == message.thread_id)
if (!thread) return if (!thread) return
const messageContent = message.content[0]?.text?.value const messageContent = message.content[0]?.text?.value
const metadata = { const metadata = {
...thread.metadata, ...thread.metadata,
...(messageContent && { lastMessage: messageContent }), ...(messageContent && { lastMessage: messageContent }),
@ -246,15 +253,19 @@ export default function ModelHandler() {
extensionManager extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.saveThread({ ?.modifyThread({
...thread, ...thread,
metadata, metadata,
}) })
;(async () => {
// If this is not the summary of the Thread, don't need to add it to the Thread const updatedMessage = await extensionManager
extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.addNewMessage(message) ?.createMessage(message)
if (updatedMessage) {
deleteMessage(message.id)
addNewMessage(updatedMessage)
}
})()
// Attempt to generate the title of the Thread when needed // Attempt to generate the title of the Thread when needed
generateThreadTitle(message, thread) generateThreadTitle(message, thread)
@ -279,7 +290,9 @@ export default function ModelHandler() {
const generateThreadTitle = (message: ThreadMessage, thread: Thread) => { const generateThreadTitle = (message: ThreadMessage, thread: Thread) => {
// If this is the first ever prompt in the thread // If this is the first ever prompt in the thread
if (thread.title?.trim() !== defaultThreadTitle) { if (
(thread.title ?? thread.metadata?.title)?.trim() !== defaultThreadTitle
) {
return return
} }
@ -292,11 +305,14 @@ export default function ModelHandler() {
const updatedThread: Thread = { const updatedThread: Thread = {
...thread, ...thread,
title: (thread.metadata?.lastMessage as string) || defaultThreadTitle, title: (thread.metadata?.lastMessage as string) || defaultThreadTitle,
metadata: thread.metadata, metadata: {
...thread.metadata,
title: (thread.metadata?.lastMessage as string) || defaultThreadTitle,
},
} }
return extensionManager return extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.saveThread({ ?.modifyThread({
...updatedThread, ...updatedThread,
}) })
.then(() => { .then(() => {

View File

@ -1,4 +1,12 @@
import { Assistant } from '@janhq/core' import { Assistant, ThreadAssistantInfo } from '@janhq/core'
import { atom } from 'jotai' import { atom } from 'jotai'
import { atomWithStorage } from 'jotai/utils'
export const assistantsAtom = atom<Assistant[]>([]) export const assistantsAtom = atom<Assistant[]>([])
/**
* Get the current active assistant
*/
export const activeAssistantAtom = atomWithStorage<
ThreadAssistantInfo | undefined
>('activeAssistant', undefined, undefined, { getOnInit: true })

View File

@ -6,6 +6,8 @@ import {
} from '@janhq/core' } from '@janhq/core'
import { atom } from 'jotai' import { atom } from 'jotai'
import { atomWithStorage } from 'jotai/utils'
import { import {
getActiveThreadIdAtom, getActiveThreadIdAtom,
updateThreadStateLastMessageAtom, updateThreadStateLastMessageAtom,
@ -13,15 +15,23 @@ import {
import { TokenSpeed } from '@/types/token' import { TokenSpeed } from '@/types/token'
const CHAT_MESSAGE_NAME = 'chatMessages'
/** /**
* Stores all chat messages for all threads * Stores all chat messages for all threads
*/ */
export const chatMessages = atom<Record<string, ThreadMessage[]>>({}) export const chatMessages = atomWithStorage<Record<string, ThreadMessage[]>>(
CHAT_MESSAGE_NAME,
{},
undefined,
{ getOnInit: true }
)
/** /**
* Stores the status of the messages load for each thread * Stores the status of the messages load for each thread
*/ */
export const readyThreadsMessagesAtom = atom<Record<string, boolean>>({}) export const readyThreadsMessagesAtom = atomWithStorage<
Record<string, boolean>
>('currentThreadMessages', {}, undefined, { getOnInit: true })
/** /**
* Store the token speed for current message * Store the token speed for current message
@ -34,6 +44,7 @@ export const getCurrentChatMessagesAtom = atom<ThreadMessage[]>((get) => {
const activeThreadId = get(getActiveThreadIdAtom) const activeThreadId = get(getActiveThreadIdAtom)
if (!activeThreadId) return [] if (!activeThreadId) return []
const messages = get(chatMessages)[activeThreadId] const messages = get(chatMessages)[activeThreadId]
if (!Array.isArray(messages)) return []
return messages ?? [] return messages ?? []
}) })

View File

@ -8,6 +8,7 @@ import { toaster } from '@/containers/Toast'
import { LAST_USED_MODEL_ID } from './useRecommendedModel' import { LAST_USED_MODEL_ID } from './useRecommendedModel'
import { vulkanEnabledAtom } from '@/helpers/atoms/AppConfig.atom' import { vulkanEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
@ -34,6 +35,7 @@ export function useActiveModel() {
const setLoadModelError = useSetAtom(loadModelErrorAtom) const setLoadModelError = useSetAtom(loadModelErrorAtom)
const pendingModelLoad = useRef(false) const pendingModelLoad = useRef(false)
const isVulkanEnabled = useAtomValue(vulkanEnabledAtom) const isVulkanEnabled = useAtomValue(vulkanEnabledAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const downloadedModelsRef = useRef<Model[]>([]) const downloadedModelsRef = useRef<Model[]>([])
@ -79,12 +81,12 @@ export function useActiveModel() {
} }
/// Apply thread model settings /// Apply thread model settings
if (activeThread?.assistants[0]?.model.id === modelId) { if (activeAssistant?.model.id === modelId) {
model = { model = {
...model, ...model,
settings: { settings: {
...model.settings, ...model.settings,
...activeThread.assistants[0].model.settings, ...activeAssistant?.model.settings,
}, },
} }
} }

View File

@ -1,7 +1,6 @@
import { useCallback } from 'react' import { useCallback } from 'react'
import { import {
Assistant,
ConversationalExtension, ConversationalExtension,
ExtensionTypeEnum, ExtensionTypeEnum,
Thread, Thread,
@ -9,16 +8,17 @@ import {
ThreadState, ThreadState,
AssistantTool, AssistantTool,
Model, Model,
Assistant,
} from '@janhq/core' } from '@janhq/core'
import { atom, useAtomValue, useSetAtom } from 'jotai' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { useDebouncedCallback } from 'use-debounce'
import { copyOverInstructionEnabledAtom } from '@/containers/CopyInstruction' import { copyOverInstructionEnabledAtom } from '@/containers/CopyInstruction'
import { fileUploadAtom } from '@/containers/Providers/Jotai' import { fileUploadAtom } from '@/containers/Providers/Jotai'
import { toaster } from '@/containers/Toast' import { toaster } from '@/containers/Toast'
import { generateThreadId } from '@/utils/thread'
import { useActiveModel } from './useActiveModel' import { useActiveModel } from './useActiveModel'
import useRecommendedModel from './useRecommendedModel' import useRecommendedModel from './useRecommendedModel'
@ -27,6 +27,7 @@ import useSetActiveThread from './useSetActiveThread'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom' import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import { import {
threadsAtom, threadsAtom,
@ -34,7 +35,6 @@ import {
updateThreadAtom, updateThreadAtom,
setThreadModelParamsAtom, setThreadModelParamsAtom,
isGeneratingResponseAtom, isGeneratingResponseAtom,
activeThreadAtom,
} from '@/helpers/atoms/Thread.atom' } from '@/helpers/atoms/Thread.atom'
const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => { const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
@ -64,7 +64,7 @@ export const useCreateNewThread = () => {
const copyOverInstructionEnabled = useAtomValue( const copyOverInstructionEnabled = useAtomValue(
copyOverInstructionEnabledAtom copyOverInstructionEnabledAtom
) )
const activeThread = useAtomValue(activeThreadAtom) const [activeAssistant, setActiveAssistant] = useAtom(activeAssistantAtom)
const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom) const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom)
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom) const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
@ -75,7 +75,7 @@ export const useCreateNewThread = () => {
const { stopInference } = useActiveModel() const { stopInference } = useActiveModel()
const requestCreateNewThread = async ( const requestCreateNewThread = async (
assistant: Assistant, assistant: (ThreadAssistantInfo & { id: string; name: string }) | Assistant,
model?: Model | undefined model?: Model | undefined
) => { ) => {
// Stop generating if any // Stop generating if any
@ -124,7 +124,7 @@ export const useCreateNewThread = () => {
const createdAt = Date.now() const createdAt = Date.now()
let instructions: string | undefined = assistant.instructions let instructions: string | undefined = assistant.instructions
if (copyOverInstructionEnabled) { if (copyOverInstructionEnabled) {
instructions = activeThread?.assistants[0]?.instructions ?? undefined instructions = activeAssistant?.instructions ?? undefined
} }
const assistantInfo: ThreadAssistantInfo = { const assistantInfo: ThreadAssistantInfo = {
assistant_id: assistant.id, assistant_id: assistant.id,
@ -139,22 +139,26 @@ export const useCreateNewThread = () => {
instructions, instructions,
} }
const threadId = generateThreadId(assistant.id) const thread: Partial<Thread> = {
const thread: Thread = {
id: threadId,
object: 'thread', object: 'thread',
title: 'New Thread', title: 'New Thread',
assistants: [assistantInfo], assistants: [assistantInfo],
created: createdAt, created: createdAt,
updated: createdAt, updated: createdAt,
metadata: {
title: 'New Thread',
},
} }
// add the new thread on top of the thread list to the state // add the new thread on top of the thread list to the state
//TODO: Why do we have thread list then thread states? Should combine them //TODO: Why do we have thread list then thread states? Should combine them
createNewThread(thread) try {
const createdThread = await persistNewThread(thread, assistantInfo)
if (!createdThread) throw 'Thread creation failed'
createNewThread(createdThread)
setSelectedModel(defaultModel) setSelectedModel(defaultModel)
setThreadModelParams(thread.id, { setThreadModelParams(createdThread.id, {
...defaultModel?.settings, ...defaultModel?.settings,
...defaultModel?.parameters, ...defaultModel?.parameters,
...overriddenSettings, ...overriddenSettings,
@ -162,22 +166,67 @@ export const useCreateNewThread = () => {
// Delete the file upload state // Delete the file upload state
setFileUpload([]) setFileUpload([])
// Update thread metadata setActiveThread(createdThread)
await updateThreadMetadata(thread) } catch (ex) {
return toaster({
setActiveThread(thread) title: 'Thread created failed.',
description: `To avoid piling up empty threads, please reuse previous one before creating new.`,
type: 'error',
})
} }
}
const updateThreadExtension = (thread: Thread) => {
return extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.modifyThread(thread)
}
const updateAssistantExtension = (
threadId: string,
assistant: ThreadAssistantInfo
) => {
return extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.modifyThreadAssistant(threadId, assistant)
}
const updateThreadCallback = useDebouncedCallback(updateThreadExtension, 300)
const updateAssistantCallback = useDebouncedCallback(
updateAssistantExtension,
300
)
const updateThreadMetadata = useCallback( const updateThreadMetadata = useCallback(
async (thread: Thread) => { async (thread: Thread) => {
updateThread(thread) updateThread(thread)
setActiveAssistant(thread.assistants[0])
updateThreadCallback(thread)
updateAssistantCallback(thread.id, thread.assistants[0])
},
[
updateThread,
setActiveAssistant,
updateThreadCallback,
updateAssistantCallback,
]
)
const persistNewThread = async (
thread: Partial<Thread>,
assistantInfo: ThreadAssistantInfo
): Promise<Thread | undefined> => {
return await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.createThread(thread)
.then(async (thread) => {
await extensionManager await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.saveThread(thread) ?.createThreadAssistant(thread.id, assistantInfo)
}, return thread
[updateThread] })
) }
return { return {
requestCreateNewThread, requestCreateNewThread,

View File

@ -1,13 +1,6 @@
import { useCallback } from 'react' import { useCallback } from 'react'
import { import { ExtensionTypeEnum, ConversationalExtension } from '@janhq/core'
ChatCompletionRole,
ExtensionTypeEnum,
ConversationalExtension,
fs,
joinPath,
Thread,
} from '@janhq/core'
import { useAtom, useAtomValue, useSetAtom } from 'jotai' import { useAtom, useAtomValue, useSetAtom } from 'jotai'
@ -15,89 +8,63 @@ import { currentPromptAtom } from '@/containers/Providers/Jotai'
import { toaster } from '@/containers/Toast' import { toaster } from '@/containers/Toast'
import { useCreateNewThread } from './useCreateNewThread'
import { extensionManager } from '@/extension/ExtensionManager' import { extensionManager } from '@/extension/ExtensionManager'
import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom' import { assistantsAtom } from '@/helpers/atoms/Assistant.atom'
import { import { deleteChatMessageAtom as deleteChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
chatMessages, import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
cleanChatMessageAtom as cleanChatMessagesAtom,
deleteChatMessageAtom as deleteChatMessagesAtom,
} from '@/helpers/atoms/ChatMessage.atom'
import { import {
threadsAtom, threadsAtom,
setActiveThreadIdAtom, setActiveThreadIdAtom,
deleteThreadStateAtom, deleteThreadStateAtom,
updateThreadStateLastMessageAtom,
updateThreadAtom,
} from '@/helpers/atoms/Thread.atom' } from '@/helpers/atoms/Thread.atom'
export default function useDeleteThread() { export default function useDeleteThread() {
const [threads, setThreads] = useAtom(threadsAtom) const [threads, setThreads] = useAtom(threadsAtom)
const messages = useAtomValue(chatMessages) const { requestCreateNewThread } = useCreateNewThread()
const janDataFolderPath = useAtomValue(janDataFolderPathAtom) const assistants = useAtomValue(assistantsAtom)
const models = useAtomValue(downloadedModelsAtom)
const setCurrentPrompt = useSetAtom(currentPromptAtom) const setCurrentPrompt = useSetAtom(currentPromptAtom)
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom) const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
const deleteMessages = useSetAtom(deleteChatMessagesAtom) const deleteMessages = useSetAtom(deleteChatMessagesAtom)
const cleanMessages = useSetAtom(cleanChatMessagesAtom)
const deleteThreadState = useSetAtom(deleteThreadStateAtom) const deleteThreadState = useSetAtom(deleteThreadStateAtom)
const updateThreadLastMessage = useSetAtom(updateThreadStateLastMessageAtom)
const updateThread = useSetAtom(updateThreadAtom)
const cleanThread = useCallback( const cleanThread = useCallback(
async (threadId: string) => { async (threadId: string) => {
cleanMessages(threadId)
const thread = threads.find((c) => c.id === threadId) const thread = threads.find((c) => c.id === threadId)
if (!thread) return if (!thread) return
const assistantInfo = await extensionManager
const updatedMessages = (messages[threadId] ?? []).filter(
(msg) => msg.role === ChatCompletionRole.System
)
// remove files
try {
const threadFolderPath = await joinPath([
janDataFolderPath,
'threads',
threadId,
])
const threadFilesPath = await joinPath([threadFolderPath, 'files'])
const threadMemoryPath = await joinPath([threadFolderPath, 'memory'])
await fs.rm(threadFilesPath)
await fs.rm(threadMemoryPath)
} catch (err) {
console.warn('Error deleting thread files', err)
}
await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.writeMessages(threadId, updatedMessages) ?.getThreadAssistant(thread.id)
thread.metadata = { if (!assistantInfo) return
...thread.metadata, const model = models.find((c) => c.id === assistantInfo?.model?.id)
}
const updatedThread: Thread = { requestCreateNewThread(
...thread, {
title: 'New Thread', ...assistantInfo,
metadata: { ...thread.metadata, lastMessage: undefined }, id: assistants[0].id,
} name: assistants[0].name,
await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.saveThread(updatedThread)
updateThreadLastMessage(threadId, undefined)
updateThread(updatedThread)
}, },
[ model
cleanMessages, ? {
threads, ...model,
messages, parameters: assistantInfo?.model?.parameters ?? {},
updateThreadLastMessage, settings: assistantInfo?.model?.settings ?? {},
updateThread, }
janDataFolderPath, : undefined
] )
// Delete this thread
await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.deleteThread(threadId)
.catch(console.error)
},
[assistants, models, requestCreateNewThread, threads]
) )
const deleteThread = async (threadId: string) => { const deleteThread = async (threadId: string) => {
@ -105,10 +72,10 @@ export default function useDeleteThread() {
alert('No active thread') alert('No active thread')
return return
} }
try {
await extensionManager await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.deleteThread(threadId) ?.deleteThread(threadId)
.catch(console.error)
const availableThreads = threads.filter((c) => c.id !== threadId) const availableThreads = threads.filter((c) => c.id !== threadId)
setThreads(availableThreads) setThreads(availableThreads)
@ -127,9 +94,6 @@ export default function useDeleteThread() {
} else { } else {
setActiveThreadId(undefined) setActiveThreadId(undefined)
} }
} catch (err) {
console.error(err)
}
} }
return { return {

View File

@ -2,6 +2,7 @@ import { openFileExplorer, joinPath, baseName } from '@janhq/core'
import { useAtomValue } from 'jotai' import { useAtomValue } from 'jotai'
import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom' import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
@ -9,13 +10,14 @@ export const usePath = () => {
const janDataFolderPath = useAtomValue(janDataFolderPathAtom) const janDataFolderPath = useAtomValue(janDataFolderPathAtom)
const activeThread = useAtomValue(activeThreadAtom) const activeThread = useAtomValue(activeThreadAtom)
const selectedModel = useAtomValue(selectedModelAtom) const selectedModel = useAtomValue(selectedModelAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const onRevealInFinder = async (type: string) => { const onRevealInFinder = async (type: string) => {
// TODO: this logic should be refactored. // TODO: this logic should be refactored.
if (type !== 'Model' && !activeThread) return if (type !== 'Model' && !activeThread) return
let filePath = undefined let filePath = undefined
const assistantId = activeThread?.assistants[0]?.assistant_id const assistantId = activeAssistant?.assistant_id
switch (type) { switch (type) {
case 'Engine': case 'Engine':
case 'Thread': case 'Thread':

View File

@ -6,6 +6,7 @@ import { atom, useAtomValue } from 'jotai'
import { activeModelAtom } from './useActiveModel' import { activeModelAtom } from './useActiveModel'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
@ -28,6 +29,7 @@ export default function useRecommendedModel() {
const [recommendedModel, setRecommendedModel] = useState<Model | undefined>() const [recommendedModel, setRecommendedModel] = useState<Model | undefined>()
const activeThread = useAtomValue(activeThreadAtom) const activeThread = useAtomValue(activeThreadAtom)
const downloadedModels = useAtomValue(downloadedModelsAtom) const downloadedModels = useAtomValue(downloadedModelsAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const getAndSortDownloadedModels = useCallback(async (): Promise<Model[]> => { const getAndSortDownloadedModels = useCallback(async (): Promise<Model[]> => {
const models = downloadedModels.sort((a, b) => const models = downloadedModels.sort((a, b) =>
@ -45,8 +47,8 @@ export default function useRecommendedModel() {
> => { > => {
const models = await getAndSortDownloadedModels() const models = await getAndSortDownloadedModels()
if (!activeThread) return if (!activeThread || !activeAssistant) return
const modelId = activeThread.assistants[0]?.model.id const modelId = activeAssistant.model.id
const model = models.find((model) => model.id === modelId) const model = models.find((model) => model.id === modelId)
if (model) { if (model) {

View File

@ -10,6 +10,7 @@ import {
ConversationalExtension, ConversationalExtension,
EngineManager, EngineManager,
ToolManager, ToolManager,
ThreadAssistantInfo,
} from '@janhq/core' } from '@janhq/core'
import { extractInferenceParams, extractModelLoadParams } from '@janhq/core' import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
@ -28,6 +29,7 @@ import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
import { useActiveModel } from './useActiveModel' import { useActiveModel } from './useActiveModel'
import { extensionManager } from '@/extension/ExtensionManager' import { extensionManager } from '@/extension/ExtensionManager'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { import {
addNewMessageAtom, addNewMessageAtom,
deleteMessageAtom, deleteMessageAtom,
@ -48,6 +50,7 @@ export const reloadModelAtom = atom(false)
export default function useSendChatMessage() { export default function useSendChatMessage() {
const activeThread = useAtomValue(activeThreadAtom) const activeThread = useAtomValue(activeThreadAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const addNewMessage = useSetAtom(addNewMessageAtom) const addNewMessage = useSetAtom(addNewMessageAtom)
const updateThread = useSetAtom(updateThreadAtom) const updateThread = useSetAtom(updateThreadAtom)
const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom) const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
@ -68,6 +71,7 @@ export default function useSendChatMessage() {
const [fileUpload, setFileUpload] = useAtom(fileUploadAtom) const [fileUpload, setFileUpload] = useAtom(fileUploadAtom)
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom) const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
const activeThreadRef = useRef<Thread | undefined>() const activeThreadRef = useRef<Thread | undefined>()
const activeAssistantRef = useRef<ThreadAssistantInfo | undefined>()
const setTokenSpeed = useSetAtom(tokenSpeedAtom) const setTokenSpeed = useSetAtom(tokenSpeedAtom)
const selectedModelRef = useRef<Model | undefined>() const selectedModelRef = useRef<Model | undefined>()
@ -84,36 +88,37 @@ export default function useSendChatMessage() {
selectedModelRef.current = selectedModel selectedModelRef.current = selectedModel
}, [selectedModel]) }, [selectedModel])
const resendChatMessage = async (currentMessage: ThreadMessage) => { useEffect(() => {
activeAssistantRef.current = activeAssistant
}, [activeAssistant])
const resendChatMessage = async () => {
// Delete last response before regenerating // Delete last response before regenerating
const newConvoData = currentMessages const newConvoData = Array.from(currentMessages)
let toSendMessage = currentMessage let toSendMessage = newConvoData.pop()
do { while (toSendMessage && toSendMessage?.role !== ChatCompletionRole.User) {
deleteMessage(currentMessage.id)
const msg = newConvoData.pop()
if (!msg) break
toSendMessage = msg
deleteMessage(toSendMessage.id ?? '')
} while (toSendMessage.role !== ChatCompletionRole.User)
if (activeThreadRef.current) {
await extensionManager await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.writeMessages(activeThreadRef.current.id, newConvoData) ?.deleteMessage(toSendMessage.thread_id, toSendMessage.id)
.catch(console.error)
deleteMessage(toSendMessage.id ?? '')
toSendMessage = newConvoData.pop()
} }
sendChatMessage(toSendMessage.content[0]?.text.value) if (toSendMessage?.content[0]?.text?.value)
sendChatMessage(toSendMessage.content[0].text.value, true)
} }
const sendChatMessage = async ( const sendChatMessage = async (
message: string, message: string,
isResend: boolean = false,
messages?: ThreadMessage[] messages?: ThreadMessage[]
) => { ) => {
if (!message || message.trim().length === 0) return if (!message || message.trim().length === 0) return
if (!activeThreadRef.current) { if (!activeThreadRef.current || !activeAssistantRef.current) {
console.error('No active thread') console.error('No active thread or assistant')
return return
} }
@ -139,11 +144,11 @@ export default function useSendChatMessage() {
} }
const modelRequest = const modelRequest =
selectedModelRef?.current ?? activeThreadRef.current.assistants[0].model selectedModelRef?.current ?? activeAssistantRef.current?.model
// Fallback support for previous broken threads // Fallback support for previous broken threads
if (activeThreadRef.current?.assistants[0]?.model?.id === '*') { if (activeAssistantRef.current?.model?.id === '*') {
activeThreadRef.current.assistants[0].model = { activeAssistantRef.current.model = {
id: modelRequest.id, id: modelRequest.id,
settings: modelRequest.settings, settings: modelRequest.settings,
parameters: modelRequest.parameters, parameters: modelRequest.parameters,
@ -163,8 +168,9 @@ export default function useSendChatMessage() {
}, },
activeThreadRef.current, activeThreadRef.current,
messages ?? currentMessages messages ?? currentMessages
).addSystemMessage(activeThreadRef.current.assistants[0].instructions) ).addSystemMessage(activeAssistantRef.current?.instructions)
if (!isResend) {
requestBuilder.pushMessage(prompt, base64Blob, fileUpload[0]?.type) requestBuilder.pushMessage(prompt, base64Blob, fileUpload[0]?.type)
// Build Thread Message to persist // Build Thread Message to persist
@ -174,9 +180,6 @@ export default function useSendChatMessage() {
const newMessage = threadMessageBuilder.build() const newMessage = threadMessageBuilder.build()
// Push to states
addNewMessage(newMessage)
// Update thread state // Update thread state
const updatedThread: Thread = { const updatedThread: Thread = {
...activeThreadRef.current, ...activeThreadRef.current,
@ -189,20 +192,25 @@ export default function useSendChatMessage() {
updateThread(updatedThread) updateThread(updatedThread)
// Add message // Add message
await extensionManager const createdMessage = await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.addNewMessage(newMessage) ?.createMessage(newMessage)
if (!createdMessage) return
// Push to states
addNewMessage(createdMessage)
}
// Start Model if not started // Start Model if not started
const modelId = const modelId =
selectedModelRef.current?.id ?? selectedModelRef.current?.id ?? activeAssistantRef.current?.model.id
activeThreadRef.current.assistants[0].model.id
if (base64Blob) { if (base64Blob) {
setFileUpload([]) setFileUpload([])
} }
if (modelRef.current?.id !== modelId) { if (modelRef.current?.id !== modelId && modelId) {
const error = await startModel(modelId).catch((error: Error) => error) const error = await startModel(modelId).catch((error: Error) => error)
if (error) { if (error) {
updateThreadWaiting(activeThreadRef.current.id, false) updateThreadWaiting(activeThreadRef.current.id, false)

View File

@ -1,12 +1,10 @@
import { ExtensionTypeEnum, Thread, ConversationalExtension } from '@janhq/core' import { ExtensionTypeEnum, Thread, ConversationalExtension } from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai' import { useSetAtom } from 'jotai'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
import { import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
readyThreadsMessagesAtom, import { setConvoMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
setConvoMessagesAtom,
} from '@/helpers/atoms/ChatMessage.atom'
import { import {
setActiveThreadIdAtom, setActiveThreadIdAtom,
setThreadModelParamsAtom, setThreadModelParamsAtom,
@ -17,21 +15,27 @@ export default function useSetActiveThread() {
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom) const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
const setThreadMessage = useSetAtom(setConvoMessagesAtom) const setThreadMessage = useSetAtom(setConvoMessagesAtom)
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom) const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
const readyMessageThreads = useAtomValue(readyThreadsMessagesAtom) const setActiveAssistant = useSetAtom(activeAssistantAtom)
const setActiveThread = async (thread: Thread) => { const setActiveThread = async (thread: Thread) => {
// Load local messages only if there are no messages in the state if (!thread?.id) return
if (!readyMessageThreads[thread?.id]) {
const messages = await getLocalThreadMessage(thread?.id)
setThreadMessage(thread?.id, messages)
}
setActiveThreadId(thread?.id) setActiveThreadId(thread?.id)
try {
const assistantInfo = await getThreadAssistant(thread.id)
setActiveAssistant(assistantInfo)
// Load local messages only if there are no messages in the state
const messages = await getLocalThreadMessage(thread.id).catch(() => [])
const modelParams: ModelParams = { const modelParams: ModelParams = {
...thread?.assistants[0]?.model?.parameters, ...assistantInfo?.model?.parameters,
...thread?.assistants[0]?.model?.settings, ...assistantInfo?.model?.settings,
} }
setThreadModelParams(thread?.id, modelParams) setThreadModelParams(thread?.id, modelParams)
setThreadMessage(thread.id, messages)
} catch (e) {
console.error(e)
}
} }
return { setActiveThread } return { setActiveThread }
@ -40,4 +44,9 @@ export default function useSetActiveThread() {
const getLocalThreadMessage = async (threadId: string) => const getLocalThreadMessage = async (threadId: string) =>
extensionManager extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.getAllMessages(threadId) ?? [] ?.listMessages(threadId) ?? []
const getThreadAssistant = async (threadId: string) =>
extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.getThreadAssistant(threadId)

View File

@ -68,6 +68,6 @@ const useThreads = () => {
const getLocalThreads = async (): Promise<Thread[]> => const getLocalThreads = async (): Promise<Thread[]> =>
(await extensionManager (await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.getThreads()) ?? [] ?.listThreads()) ?? []
export default useThreads export default useThreads

View File

@ -12,7 +12,10 @@ import {
import { useAtom, useAtomValue, useSetAtom } from 'jotai' import { useAtom, useAtomValue, useSetAtom } from 'jotai'
import { useDebouncedCallback } from 'use-debounce'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import { import {
getActiveThreadModelParamsAtom, getActiveThreadModelParamsAtom,
@ -29,11 +32,28 @@ export type UpdateModelParameter = {
export default function useUpdateModelParameters() { export default function useUpdateModelParameters() {
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom) const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
const [activeAssistant, setActiveAssistant] = useAtom(activeAssistantAtom)
const [selectedModel] = useAtom(selectedModelAtom) const [selectedModel] = useAtom(selectedModelAtom)
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom) const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
const updateAssistantExtension = (
threadId: string,
assistant: ThreadAssistantInfo
) => {
return extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.modifyThreadAssistant(threadId, assistant)
}
const updateAssistantCallback = useDebouncedCallback(
updateAssistantExtension,
300
)
const updateModelParameter = useCallback( const updateModelParameter = useCallback(
async (thread: Thread, settings: UpdateModelParameter) => { async (thread: Thread, settings: UpdateModelParameter) => {
if (!activeAssistant) return
const toUpdateSettings = processStopWords(settings.params ?? {}) const toUpdateSettings = processStopWords(settings.params ?? {})
const updatedModelParams = settings.modelId const updatedModelParams = settings.modelId
? toUpdateSettings ? toUpdateSettings
@ -48,30 +68,33 @@ export default function useUpdateModelParameters() {
setThreadModelParams(thread.id, updatedModelParams) setThreadModelParams(thread.id, updatedModelParams)
const runtimeParams = extractInferenceParams(updatedModelParams) const runtimeParams = extractInferenceParams(updatedModelParams)
const settingParams = extractModelLoadParams(updatedModelParams) const settingParams = extractModelLoadParams(updatedModelParams)
const assistantInfo = {
const assistants = thread.assistants.map( ...activeAssistant,
(assistant: ThreadAssistantInfo) => { model: {
assistant.model.parameters = runtimeParams ...activeAssistant?.model,
assistant.model.settings = settingParams parameters: runtimeParams,
if (selectedModel) { settings: settingParams,
assistant.model.id = settings.modelId ?? selectedModel?.id id: settings.modelId ?? selectedModel?.id ?? activeAssistant.model.id,
assistant.model.engine = settings.engine ?? selectedModel?.engine engine:
} settings.engine ??
return assistant selectedModel?.engine ??
} activeAssistant.model.engine,
)
// update thread
const updatedThread: Thread = {
...thread,
assistants,
}
await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.saveThread(updatedThread)
}, },
[activeModelParams, selectedModel, setThreadModelParams] }
setActiveAssistant(assistantInfo)
updateAssistantCallback(thread.id, assistantInfo)
},
[
activeAssistant,
selectedModel?.parameters,
selectedModel?.settings,
selectedModel?.id,
selectedModel?.engine,
activeModelParams,
setThreadModelParams,
setActiveAssistant,
updateAssistantCallback,
]
) )
const processStopWords = (params: ModelParams): ModelParams => { const processStopWords = (params: ModelParams): ModelParams => {

View File

@ -8,6 +8,7 @@ import { useCreateNewThread } from '@/hooks/useCreateNewThread'
import SettingComponentBuilder from '../../../../containers/ModelSetting/SettingComponent' import SettingComponentBuilder from '../../../../containers/ModelSetting/SettingComponent'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { import {
activeThreadAtom, activeThreadAtom,
engineParamsUpdateAtom, engineParamsUpdateAtom,
@ -19,13 +20,14 @@ type Props = {
const AssistantSetting: React.FC<Props> = ({ componentData }) => { const AssistantSetting: React.FC<Props> = ({ componentData }) => {
const activeThread = useAtomValue(activeThreadAtom) const activeThread = useAtomValue(activeThreadAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const { updateThreadMetadata } = useCreateNewThread() const { updateThreadMetadata } = useCreateNewThread()
const { stopModel } = useActiveModel() const { stopModel } = useActiveModel()
const setEngineParamsUpdate = useSetAtom(engineParamsUpdateAtom) const setEngineParamsUpdate = useSetAtom(engineParamsUpdateAtom)
const onValueChanged = useCallback( const onValueChanged = useCallback(
(key: string, value: string | number | boolean | string[]) => { (key: string, value: string | number | boolean | string[]) => {
if (!activeThread) return if (!activeThread || !activeAssistant) return
const shouldReloadModel = const shouldReloadModel =
componentData.find((x) => x.key === key)?.requireModelReload ?? false componentData.find((x) => x.key === key)?.requireModelReload ?? false
if (shouldReloadModel) { if (shouldReloadModel) {
@ -34,40 +36,40 @@ const AssistantSetting: React.FC<Props> = ({ componentData }) => {
} }
if ( if (
activeThread.assistants[0].tools && activeAssistant?.tools &&
(key === 'chunk_overlap' || key === 'chunk_size') (key === 'chunk_overlap' || key === 'chunk_size')
) { ) {
if ( if (
activeThread.assistants[0].tools[0]?.settings?.chunk_size < activeAssistant.tools[0]?.settings?.chunk_size <
activeThread.assistants[0].tools[0]?.settings?.chunk_overlap activeAssistant.tools[0]?.settings?.chunk_overlap
) { ) {
activeThread.assistants[0].tools[0].settings.chunk_overlap = activeAssistant.tools[0].settings.chunk_overlap =
activeThread.assistants[0].tools[0].settings.chunk_size activeAssistant.tools[0].settings.chunk_size
} }
if ( if (
key === 'chunk_size' && key === 'chunk_size' &&
value < activeThread.assistants[0].tools[0].settings?.chunk_overlap value < activeAssistant.tools[0].settings?.chunk_overlap
) { ) {
activeThread.assistants[0].tools[0].settings.chunk_overlap = value activeAssistant.tools[0].settings.chunk_overlap = value
} else if ( } else if (
key === 'chunk_overlap' && key === 'chunk_overlap' &&
value > activeThread.assistants[0].tools[0].settings?.chunk_size value > activeAssistant.tools[0].settings?.chunk_size
) { ) {
activeThread.assistants[0].tools[0].settings.chunk_size = value activeAssistant.tools[0].settings.chunk_size = value
} }
} }
updateThreadMetadata({ updateThreadMetadata({
...activeThread, ...activeThread,
assistants: [ assistants: [
{ {
...activeThread.assistants[0], ...activeAssistant,
tools: [ tools: [
{ {
type: 'retrieval', type: 'retrieval',
enabled: true, enabled: true,
settings: { settings: {
...(activeThread.assistants[0].tools && ...(activeAssistant.tools &&
activeThread.assistants[0].tools[0]?.settings), activeAssistant.tools[0]?.settings),
[key]: value, [key]: value,
}, },
}, },
@ -77,6 +79,7 @@ const AssistantSetting: React.FC<Props> = ({ componentData }) => {
}) })
}, },
[ [
activeAssistant,
activeThread, activeThread,
componentData, componentData,
setEngineParamsUpdate, setEngineParamsUpdate,

View File

@ -33,6 +33,7 @@ import RichTextEditor from './RichTextEditor'
import { showRightPanelAtom } from '@/helpers/atoms/App.atom' import { showRightPanelAtom } from '@/helpers/atoms/App.atom'
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom' import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom' import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import { spellCheckAtom } from '@/helpers/atoms/Setting.atom' import { spellCheckAtom } from '@/helpers/atoms/Setting.atom'
@ -67,6 +68,7 @@ const ChatInput = () => {
const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom) const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom)
const isGeneratingResponse = useAtomValue(isGeneratingResponseAtom) const isGeneratingResponse = useAtomValue(isGeneratingResponseAtom)
const threadStates = useAtomValue(threadStatesAtom) const threadStates = useAtomValue(threadStatesAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const { stopInference } = useActiveModel() const { stopInference } = useActiveModel()
const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom( const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom(
@ -153,9 +155,9 @@ const ChatInput = () => {
onClick={(e) => { onClick={(e) => {
if ( if (
fileUpload.length > 0 || fileUpload.length > 0 ||
(activeThread?.assistants[0].tools && (activeAssistant?.tools &&
!activeThread?.assistants[0].tools[0]?.enabled && !activeAssistant?.tools[0]?.enabled &&
!activeThread?.assistants[0].model.settings?.vision_model) !activeAssistant?.model.settings?.vision_model)
) { ) {
e.stopPropagation() e.stopPropagation()
} else { } else {
@ -171,16 +173,15 @@ const ChatInput = () => {
} }
disabled={ disabled={
isModelSupportRagAndTools && isModelSupportRagAndTools &&
activeThread?.assistants[0].tools && activeAssistant?.tools &&
activeThread?.assistants[0].tools[0]?.enabled activeAssistant?.tools[0]?.enabled
} }
content={ content={
<> <>
{fileUpload.length > 0 || {fileUpload.length > 0 ||
(activeThread?.assistants[0].tools && (activeAssistant?.tools &&
!activeThread?.assistants[0].tools[0]?.enabled && !activeAssistant?.tools[0]?.enabled &&
!activeThread?.assistants[0].model.settings !activeAssistant?.model.settings?.vision_model && (
?.vision_model && (
<> <>
{fileUpload.length !== 0 && ( {fileUpload.length !== 0 && (
<span> <span>
@ -188,9 +189,8 @@ const ChatInput = () => {
time. time.
</span> </span>
)} )}
{activeThread?.assistants[0].tools && {activeAssistant?.tools &&
activeThread?.assistants[0].tools[0]?.enabled === activeAssistant?.tools[0]?.enabled === false &&
false &&
isModelSupportRagAndTools && ( isModelSupportRagAndTools && (
<span> <span>
Turn on Retrieval in Tools settings to use this Turn on Retrieval in Tools settings to use this
@ -221,14 +221,12 @@ const ChatInput = () => {
<li <li
className={twMerge( className={twMerge(
'text-[hsla(var(--text-secondary)] hover:bg-secondary flex w-full items-center space-x-2 px-4 py-2 hover:bg-[hsla(var(--dropdown-menu-hover-bg))]', 'text-[hsla(var(--text-secondary)] hover:bg-secondary flex w-full items-center space-x-2 px-4 py-2 hover:bg-[hsla(var(--dropdown-menu-hover-bg))]',
activeThread?.assistants[0].model.settings?.vision_model activeAssistant?.model.settings?.vision_model
? 'cursor-pointer' ? 'cursor-pointer'
: 'cursor-not-allowed opacity-50' : 'cursor-not-allowed opacity-50'
)} )}
onClick={() => { onClick={() => {
if ( if (activeAssistant?.model.settings?.vision_model) {
activeThread?.assistants[0].model.settings?.vision_model
) {
imageInputRef.current?.click() imageInputRef.current?.click()
setShowAttacmentMenus(false) setShowAttacmentMenus(false)
} }
@ -239,9 +237,7 @@ const ChatInput = () => {
</li> </li>
} }
content="This feature only supports multimodal models." content="This feature only supports multimodal models."
disabled={ disabled={activeAssistant?.model.settings?.vision_model}
activeThread?.assistants[0].model.settings?.vision_model
}
/> />
<Tooltip <Tooltip
side="bottom" side="bottom"
@ -261,8 +257,8 @@ const ChatInput = () => {
</li> </li>
} }
content={ content={
(!activeThread?.assistants[0].tools || (!activeAssistant?.tools ||
!activeThread?.assistants[0].tools[0]?.enabled) && ( !activeAssistant?.tools[0]?.enabled) && (
<span> <span>
Turn on Retrieval in Assistant Settings to use this Turn on Retrieval in Assistant Settings to use this
feature. feature.

View File

@ -80,19 +80,17 @@ const EditChatInput: React.FC<Props> = ({ message }) => {
setEditMessage('') setEditMessage('')
const messageIdx = messages.findIndex((msg) => msg.id === message.id) const messageIdx = messages.findIndex((msg) => msg.id === message.id)
const newMessages = messages.slice(0, messageIdx) const newMessages = messages.slice(0, messageIdx)
if (activeThread) { const toDeleteMessages = messages.slice(messageIdx)
setMessages(activeThread.id, newMessages) const threadId = messages[0].thread_id
await extensionManager await Promise.all(
toDeleteMessages.map(async (message) =>
extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.writeMessages( ?.deleteMessage(message.thread_id, message.id)
activeThread.id,
// Remove all of the messages below this
newMessages
) )
.then(() => { )
sendChatMessage(editPrompt, newMessages) setMessages(threadId, newMessages)
}) sendChatMessage(editPrompt, false, newMessages)
}
} }
const onKeyDown = async (e: React.KeyboardEvent<HTMLTextAreaElement>) => { const onKeyDown = async (e: React.KeyboardEvent<HTMLTextAreaElement>) => {

View File

@ -10,15 +10,15 @@ import { MainViewState } from '@/constants/screens'
import { loadModelErrorAtom } from '@/hooks/useActiveModel' import { loadModelErrorAtom } from '@/hooks/useActiveModel'
import { mainViewStateAtom } from '@/helpers/atoms/App.atom' import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom' import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
const LoadModelError = () => { const LoadModelError = () => {
const setModalTroubleShooting = useSetAtom(modalTroubleShootingAtom) const setModalTroubleShooting = useSetAtom(modalTroubleShootingAtom)
const loadModelError = useAtomValue(loadModelErrorAtom) const loadModelError = useAtomValue(loadModelErrorAtom)
const setMainState = useSetAtom(mainViewStateAtom) const setMainState = useSetAtom(mainViewStateAtom)
const setSelectedSettingScreen = useSetAtom(selectedSettingAtom) const setSelectedSettingScreen = useSetAtom(selectedSettingAtom)
const activeThread = useAtomValue(activeThreadAtom) const activeAssistant = useAtomValue(activeAssistantAtom)
const ErrorMessage = () => { const ErrorMessage = () => {
if ( if (
@ -33,9 +33,9 @@ const LoadModelError = () => {
className="cursor-pointer font-medium text-[hsla(var(--app-link))]" className="cursor-pointer font-medium text-[hsla(var(--app-link))]"
onClick={() => { onClick={() => {
setMainState(MainViewState.Settings) setMainState(MainViewState.Settings)
if (activeThread?.assistants[0]?.model.engine) { if (activeAssistant?.model.engine) {
const engine = EngineManager.instance().get( const engine = EngineManager.instance().get(
activeThread.assistants[0].model.engine activeAssistant.model.engine
) )
engine?.name && setSelectedSettingScreen(engine.name) engine?.name && setSelectedSettingScreen(engine.name)
} }

View File

@ -58,12 +58,8 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
// Should also delete error messages to clear out the error state // Should also delete error messages to clear out the error state
await extensionManager await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.writeMessages( ?.deleteMessage(thread.id, message.id)
thread.id, .catch(console.error)
messages.filter(
(msg) => msg.id !== message.id && msg.status !== MessageStatus.Error
)
)
const updatedThread: Thread = { const updatedThread: Thread = {
...thread, ...thread,
@ -89,10 +85,6 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
setEditMessage(message.id ?? '') setEditMessage(message.id ?? '')
} }
const onRegenerateClick = async () => {
resendChatMessage(message)
}
if (message.status === MessageStatus.Pending) return null if (message.status === MessageStatus.Pending) return null
return ( return (
@ -122,7 +114,7 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
ContentType.Pdf && ( ContentType.Pdf && (
<div <div
className="cursor-pointer rounded-lg border border-[hsla(var(--app-border))] p-2" className="cursor-pointer rounded-lg border border-[hsla(var(--app-border))] p-2"
onClick={onRegenerateClick} onClick={resendChatMessage}
> >
<Tooltip <Tooltip
trigger={ trigger={

View File

@ -17,11 +17,11 @@ import DocMessage from './DocMessage'
import ImageMessage from './ImageMessage' import ImageMessage from './ImageMessage'
import { MarkdownTextMessage } from './MarkdownTextMessage' import { MarkdownTextMessage } from './MarkdownTextMessage'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { import {
editMessageAtom, editMessageAtom,
tokenSpeedAtom, tokenSpeedAtom,
} from '@/helpers/atoms/ChatMessage.atom' } from '@/helpers/atoms/ChatMessage.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
const MessageContainer: React.FC< const MessageContainer: React.FC<
ThreadMessage & { isCurrentMessage: boolean } ThreadMessage & { isCurrentMessage: boolean }
@ -29,7 +29,7 @@ const MessageContainer: React.FC<
const isUser = props.role === ChatCompletionRole.User const isUser = props.role === ChatCompletionRole.User
const isSystem = props.role === ChatCompletionRole.System const isSystem = props.role === ChatCompletionRole.System
const editMessage = useAtomValue(editMessageAtom) const editMessage = useAtomValue(editMessageAtom)
const activeThread = useAtomValue(activeThreadAtom) const activeAssistant = useAtomValue(activeAssistantAtom)
const tokenSpeed = useAtomValue(tokenSpeedAtom) const tokenSpeed = useAtomValue(tokenSpeedAtom)
const text = useMemo( const text = useMemo(
@ -75,10 +75,10 @@ const MessageContainer: React.FC<
> >
{isUser {isUser
? props.role ? props.role
: (activeThread?.assistants[0].assistant_name ?? props.role)} : (activeAssistant?.assistant_name ?? props.role)}
</div> </div>
<p className="text-xs font-medium text-gray-400"> <p className="text-xs font-medium text-gray-400">
{displayDate(props.created)} {props.created && displayDate(props.created ?? new Date())}
</p> </p>
{tokenSpeed && {tokenSpeed &&
tokenSpeed.message === props.id && tokenSpeed.message === props.id &&

View File

@ -27,6 +27,7 @@ import RequestDownloadModel from './RequestDownloadModel'
import { showSystemMonitorPanelAtom } from '@/helpers/atoms/App.atom' import { showSystemMonitorPanelAtom } from '@/helpers/atoms/App.atom'
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom' import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
import { import {
@ -55,9 +56,9 @@ const ThreadCenterPanel = () => {
const setFileUpload = useSetAtom(fileUploadAtom) const setFileUpload = useSetAtom(fileUploadAtom)
const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom) const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom)
const activeThread = useAtomValue(activeThreadAtom) const activeThread = useAtomValue(activeThreadAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const acceptedFormat: Accept = activeThread?.assistants[0].model.settings const acceptedFormat: Accept = activeAssistant?.model.settings?.vision_model
?.vision_model
? { ? {
'application/pdf': ['.pdf'], 'application/pdf': ['.pdf'],
'image/jpeg': ['.jpeg'], 'image/jpeg': ['.jpeg'],
@ -78,14 +79,13 @@ const ThreadCenterPanel = () => {
if (!experimentalFeature) return if (!experimentalFeature) return
if ( if (
e.dataTransfer.items.length === 1 && e.dataTransfer.items.length === 1 &&
((activeThread?.assistants[0].tools && ((activeAssistant?.tools && activeAssistant?.tools[0]?.enabled) ||
activeThread?.assistants[0].tools[0]?.enabled) || activeAssistant?.model.settings?.vision_model)
activeThread?.assistants[0].model.settings?.vision_model)
) { ) {
setDragOver(true) setDragOver(true)
} else if ( } else if (
activeThread?.assistants[0].tools && activeAssistant?.tools &&
!activeThread?.assistants[0].tools[0]?.enabled !activeAssistant?.tools[0]?.enabled
) { ) {
setDragRejected({ code: 'retrieval-off' }) setDragRejected({ code: 'retrieval-off' })
} else { } else {
@ -100,9 +100,9 @@ const ThreadCenterPanel = () => {
!files || !files ||
files.length !== 1 || files.length !== 1 ||
rejectFiles.length !== 0 || rejectFiles.length !== 0 ||
(activeThread?.assistants[0].tools && (activeAssistant?.tools &&
!activeThread?.assistants[0].tools[0]?.enabled && !activeAssistant?.tools[0]?.enabled &&
!activeThread?.assistants[0].model.settings?.vision_model) !activeAssistant?.model.settings?.vision_model)
) )
return return
const imageType = files[0]?.type.includes('image') const imageType = files[0]?.type.includes('image')
@ -110,10 +110,7 @@ const ThreadCenterPanel = () => {
setDragOver(false) setDragOver(false)
}, },
onDropRejected: (e) => { onDropRejected: (e) => {
if ( if (activeAssistant?.tools && !activeAssistant?.tools[0]?.enabled) {
activeThread?.assistants[0].tools &&
!activeThread?.assistants[0].tools[0]?.enabled
) {
setDragRejected({ code: 'retrieval-off' }) setDragRejected({ code: 'retrieval-off' })
} else { } else {
setDragRejected({ code: e[0].errors[0].code }) setDragRejected({ code: e[0].errors[0].code })
@ -186,8 +183,7 @@ const ThreadCenterPanel = () => {
<h6 className="font-bold"> <h6 className="font-bold">
{isDragReject {isDragReject
? `Currently, we only support 1 attachment at the same time with ${ ? `Currently, we only support 1 attachment at the same time with ${
activeThread?.assistants[0].model.settings activeAssistant?.model.settings?.vision_model
?.vision_model
? 'PDF, JPEG, JPG, PNG' ? 'PDF, JPEG, JPG, PNG'
: 'PDF' : 'PDF'
} format` } format`
@ -195,7 +191,7 @@ const ThreadCenterPanel = () => {
</h6> </h6>
{!isDragReject && ( {!isDragReject && (
<p className="mt-2"> <p className="mt-2">
{activeThread?.assistants[0].model.settings?.vision_model {activeAssistant?.model.settings?.vision_model
? 'PDF, JPEG, JPG, PNG' ? 'PDF, JPEG, JPG, PNG'
: 'PDF'} : 'PDF'}
</p> </p>

View File

@ -15,13 +15,15 @@ const ModalEditTitleThread = () => {
const [modalActionThread, setModalActionThread] = useAtom( const [modalActionThread, setModalActionThread] = useAtom(
modalActionThreadAtom modalActionThreadAtom
) )
const [title, setTitle] = useState(modalActionThread.thread?.title as string) const [title, setTitle] = useState(
modalActionThread.thread?.metadata?.title as string
)
useLayoutEffect(() => { useLayoutEffect(() => {
if (modalActionThread.thread?.title) { if (modalActionThread.thread?.metadata?.title) {
setTitle(modalActionThread.thread?.title) setTitle(modalActionThread.thread?.metadata?.title as string)
} }
}, [modalActionThread.thread?.title]) }, [modalActionThread.thread?.metadata])
const onUpdateTitle = useCallback( const onUpdateTitle = useCallback(
(e: React.MouseEvent<HTMLButtonElement, MouseEvent>) => { (e: React.MouseEvent<HTMLButtonElement, MouseEvent>) => {
@ -30,6 +32,10 @@ const ModalEditTitleThread = () => {
updateThreadMetadata({ updateThreadMetadata({
...modalActionThread?.thread, ...modalActionThread?.thread,
title: title || 'New Thread', title: title || 'New Thread',
metadata: {
...modalActionThread?.thread.metadata,
title: title || 'New Thread',
},
}) })
}, },
[modalActionThread?.thread, title, updateThreadMetadata] [modalActionThread?.thread, title, updateThreadMetadata]

View File

@ -20,7 +20,10 @@ import { useCreateNewThread } from '@/hooks/useCreateNewThread'
import useRecommendedModel from '@/hooks/useRecommendedModel' import useRecommendedModel from '@/hooks/useRecommendedModel'
import useSetActiveThread from '@/hooks/useSetActiveThread' import useSetActiveThread from '@/hooks/useSetActiveThread'
import { assistantsAtom } from '@/helpers/atoms/Assistant.atom' import {
activeAssistantAtom,
assistantsAtom,
} from '@/helpers/atoms/Assistant.atom'
import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom' import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom'
import { import {
@ -34,6 +37,7 @@ import {
const ThreadLeftPanel = () => { const ThreadLeftPanel = () => {
const threads = useAtomValue(threadsAtom) const threads = useAtomValue(threadsAtom)
const activeThreadId = useAtomValue(getActiveThreadIdAtom) const activeThreadId = useAtomValue(getActiveThreadIdAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const { setActiveThread } = useSetActiveThread() const { setActiveThread } = useSetActiveThread()
const assistants = useAtomValue(assistantsAtom) const assistants = useAtomValue(assistantsAtom)
const threadDataReady = useAtomValue(threadDataReadyAtom) const threadDataReady = useAtomValue(threadDataReadyAtom)
@ -67,6 +71,7 @@ const ThreadLeftPanel = () => {
useEffect(() => { useEffect(() => {
if ( if (
threadDataReady && threadDataReady &&
activeAssistant &&
assistants.length > 0 && assistants.length > 0 &&
threads.length === 0 && threads.length === 0 &&
downloadedModels.length > 0 downloadedModels.length > 0
@ -75,7 +80,13 @@ const ThreadLeftPanel = () => {
(model) => model.engine === InferenceEngine.cortex_llamacpp (model) => model.engine === InferenceEngine.cortex_llamacpp
) )
const selectedModel = model[0] || recommendedModel const selectedModel = model[0] || recommendedModel
requestCreateNewThread(assistants[0], selectedModel) requestCreateNewThread(
{
...assistants[0],
...activeAssistant,
},
selectedModel
)
} else if (threadDataReady && !activeThreadId) { } else if (threadDataReady && !activeThreadId) {
setActiveThread(threads[0]) setActiveThread(threads[0])
} }
@ -88,6 +99,7 @@ const ThreadLeftPanel = () => {
setActiveThread, setActiveThread,
recommendedModel, recommendedModel,
downloadedModels, downloadedModels,
activeAssistant,
]) ])
const onContextMenu = (event: React.MouseEvent, thread: Thread) => { const onContextMenu = (event: React.MouseEvent, thread: Thread) => {
@ -138,7 +150,7 @@ const ThreadLeftPanel = () => {
activeThreadId && 'font-medium' activeThreadId && 'font-medium'
)} )}
> >
{thread.title} {thread.title ?? thread.metadata?.title}
</h1> </h1>
</div> </div>
<div <div

View File

@ -14,48 +14,54 @@ import AssistantSetting from '@/screens/Thread/ThreadCenterPanel/AssistantSettin
import { getConfigurationsData } from '@/utils/componentSettings' import { getConfigurationsData } from '@/utils/componentSettings'
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom' import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
const Tools = () => { const Tools = () => {
const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom) const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom)
const activeThread = useAtomValue(activeThreadAtom) const activeThread = useAtomValue(activeThreadAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom) const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
const { updateThreadMetadata } = useCreateNewThread() const { updateThreadMetadata } = useCreateNewThread()
const { recommendedModel, downloadedModels } = useRecommendedModel() const { recommendedModel, downloadedModels } = useRecommendedModel()
const componentDataAssistantSetting = getConfigurationsData( const componentDataAssistantSetting = getConfigurationsData(
(activeThread?.assistants[0]?.tools && (activeAssistant?.tools && activeAssistant?.tools[0]?.settings) ?? {}
activeThread?.assistants[0]?.tools[0]?.settings) ??
{}
) )
useEffect(() => { useEffect(() => {
if (!activeThread) return if (!activeThread) return
let model = downloadedModels.find( let model = downloadedModels.find(
(model) => model.id === activeThread.assistants[0].model.id (model) => model.id === activeAssistant?.model.id
) )
if (!model) { if (!model) {
model = recommendedModel model = recommendedModel
} }
setSelectedModel(model) setSelectedModel(model)
}, [recommendedModel, activeThread, downloadedModels, setSelectedModel]) }, [
recommendedModel,
activeThread,
downloadedModels,
setSelectedModel,
activeAssistant?.model.id,
])
const onRetrievalSwitchUpdate = useCallback( const onRetrievalSwitchUpdate = useCallback(
(enabled: boolean) => { (enabled: boolean) => {
if (!activeThread) return if (!activeThread || !activeAssistant) return
updateThreadMetadata({ updateThreadMetadata({
...activeThread, ...activeThread,
assistants: [ assistants: [
{ {
...activeThread.assistants[0], ...activeAssistant,
tools: [ tools: [
{ {
type: 'retrieval', type: 'retrieval',
enabled: enabled, enabled: enabled,
settings: settings:
(activeThread.assistants[0].tools && (activeAssistant.tools &&
activeThread.assistants[0].tools[0]?.settings) ?? activeAssistant.tools[0]?.settings) ??
{}, {},
}, },
], ],
@ -63,25 +69,25 @@ const Tools = () => {
], ],
}) })
}, },
[activeThread, updateThreadMetadata] [activeAssistant, activeThread, updateThreadMetadata]
) )
const onTimeWeightedRetrieverSwitchUpdate = useCallback( const onTimeWeightedRetrieverSwitchUpdate = useCallback(
(enabled: boolean) => { (enabled: boolean) => {
if (!activeThread) return if (!activeThread || !activeAssistant) return
updateThreadMetadata({ updateThreadMetadata({
...activeThread, ...activeThread,
assistants: [ assistants: [
{ {
...activeThread.assistants[0], ...activeAssistant,
tools: [ tools: [
{ {
type: 'retrieval', type: 'retrieval',
enabled: true, enabled: true,
useTimeWeightedRetriever: enabled, useTimeWeightedRetriever: enabled,
settings: settings:
(activeThread.assistants[0].tools && (activeAssistant.tools &&
activeThread.assistants[0].tools[0]?.settings) ?? activeAssistant.tools[0]?.settings) ??
{}, {},
}, },
], ],
@ -89,15 +95,14 @@ const Tools = () => {
], ],
}) })
}, },
[activeThread, updateThreadMetadata] [activeAssistant, activeThread, updateThreadMetadata]
) )
if (!experimentalFeature) return null if (!experimentalFeature) return null
return ( return (
<Fragment> <Fragment>
{activeThread?.assistants[0]?.tools && {activeAssistant?.tools && componentDataAssistantSetting.length > 0 && (
componentDataAssistantSetting.length > 0 && (
<div className="p-4"> <div className="p-4">
<div className="mb-2"> <div className="mb-2">
<div className="flex items-center justify-between"> <div className="flex items-center justify-between">
@ -122,13 +127,13 @@ const Tools = () => {
<div className="flex items-center justify-between"> <div className="flex items-center justify-between">
<Switch <Switch
name="retrieval" name="retrieval"
checked={activeThread?.assistants[0].tools[0].enabled} checked={activeAssistant?.tools[0].enabled}
onChange={(e) => onRetrievalSwitchUpdate(e.target.checked)} onChange={(e) => onRetrievalSwitchUpdate(e.target.checked)}
/> />
</div> </div>
</div> </div>
</div> </div>
{activeThread?.assistants[0]?.tools[0].enabled && ( {activeAssistant?.tools[0].enabled && (
<div className="pb-4 pt-2"> <div className="pb-4 pt-2">
<div className="mb-4"> <div className="mb-4">
<div className="item-center mb-2 flex"> <div className="item-center mb-2 flex">
@ -155,11 +160,7 @@ const Tools = () => {
/> />
</div> </div>
<div className="w-full"> <div className="w-full">
<Input <Input value={selectedModel?.name || ''} disabled readOnly />
value={selectedModel?.name || ''}
disabled
readOnly
/>
</div> </div>
</div> </div>
<div className="mb-4"> <div className="mb-4">
@ -214,8 +215,8 @@ const Tools = () => {
<Switch <Switch
name="use-time-weighted-retriever" name="use-time-weighted-retriever"
checked={ checked={
activeThread?.assistants[0].tools[0] activeAssistant?.tools[0].useTimeWeightedRetriever ||
.useTimeWeightedRetriever || false false
} }
onChange={(e) => onChange={(e) =>
onTimeWeightedRetrieverSwitchUpdate(e.target.checked) onTimeWeightedRetrieverSwitchUpdate(e.target.checked)
@ -224,9 +225,7 @@ const Tools = () => {
</div> </div>
</div> </div>
</div> </div>
<AssistantSetting <AssistantSetting componentData={componentDataAssistantSetting} />
componentData={componentDataAssistantSetting}
/>
</div> </div>
)} )}
</div> </div>

View File

@ -38,6 +38,7 @@ import PromptTemplateSetting from './PromptTemplateSetting'
import Tools from './Tools' import Tools from './Tools'
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom' import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import { import {
activeThreadAtom, activeThreadAtom,
@ -53,6 +54,7 @@ const ENGINE_SETTINGS = 'Engine Settings'
const ThreadRightPanel = () => { const ThreadRightPanel = () => {
const activeThread = useAtomValue(activeThreadAtom) const activeThread = useAtomValue(activeThreadAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom) const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
const selectedModel = useAtomValue(selectedModelAtom) const selectedModel = useAtomValue(selectedModelAtom)
const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom( const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom(
@ -154,18 +156,18 @@ const ThreadRightPanel = () => {
const onAssistantInstructionChanged = useCallback( const onAssistantInstructionChanged = useCallback(
(e: React.ChangeEvent<HTMLTextAreaElement>) => { (e: React.ChangeEvent<HTMLTextAreaElement>) => {
if (activeThread) if (activeThread && activeAssistant)
updateThreadMetadata({ updateThreadMetadata({
...activeThread, ...activeThread,
assistants: [ assistants: [
{ {
...activeThread.assistants[0], ...activeAssistant,
instructions: e.target.value || '', instructions: e.target.value || '',
}, },
], ],
}) })
}, },
[activeThread, updateThreadMetadata] [activeAssistant, activeThread, updateThreadMetadata]
) )
const resetModel = useDebouncedCallback(() => { const resetModel = useDebouncedCallback(() => {
@ -174,7 +176,7 @@ const ThreadRightPanel = () => {
const onValueChanged = useCallback( const onValueChanged = useCallback(
(key: string, value: string | number | boolean | string[]) => { (key: string, value: string | number | boolean | string[]) => {
if (!activeThread) { if (!activeThread || !activeAssistant) {
return return
} }
@ -186,32 +188,38 @@ const ThreadRightPanel = () => {
}) })
if ( if (
activeThread.assistants[0].model.parameters?.max_tokens && activeAssistant.model.parameters?.max_tokens &&
activeThread.assistants[0].model.settings?.ctx_len activeAssistant.model.settings?.ctx_len
) { ) {
if ( if (
key === 'max_tokens' && key === 'max_tokens' &&
Number(value) > activeThread.assistants[0].model.settings.ctx_len Number(value) > activeAssistant.model.settings.ctx_len
) { ) {
updateModelParameter(activeThread, { updateModelParameter(activeThread, {
params: { params: {
max_tokens: activeThread.assistants[0].model.settings.ctx_len, max_tokens: activeAssistant.model.settings.ctx_len,
}, },
}) })
} }
if ( if (
key === 'ctx_len' && key === 'ctx_len' &&
Number(value) < activeThread.assistants[0].model.parameters.max_tokens Number(value) < activeAssistant.model.parameters.max_tokens
) { ) {
updateModelParameter(activeThread, { updateModelParameter(activeThread, {
params: { params: {
max_tokens: activeThread.assistants[0].model.settings.ctx_len, max_tokens: activeAssistant.model.settings.ctx_len,
}, },
}) })
} }
} }
}, },
[activeThread, resetModel, setEngineParamsUpdate, updateModelParameter] [
activeAssistant,
activeThread,
resetModel,
setEngineParamsUpdate,
updateModelParameter,
]
) )
if (!activeThread) { if (!activeThread) {
@ -250,7 +258,7 @@ const ThreadRightPanel = () => {
<TextArea <TextArea
id="assistant-instructions" id="assistant-instructions"
placeholder="Eg. You are a helpful assistant." placeholder="Eg. You are a helpful assistant."
value={activeThread?.assistants[0].instructions ?? ''} value={activeAssistant?.instructions ?? ''}
autoResize autoResize
onChange={onAssistantInstructionChanged} onChange={onAssistantInstructionChanged}
/> />