Merge pull request #4249 from janhq/feat/threads-messages-requests-to-backend

feat: reroute threads and messages requests to cortex.cpp backend
This commit is contained in:
Louis 2024-12-17 10:56:24 +07:00 committed by GitHub
commit a3b3287327
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
68 changed files with 1253 additions and 1526 deletions

View File

@ -1,4 +1,10 @@
import { Thread, ThreadInterface, ThreadMessage, MessageInterface } from '../../types'
import {
Thread,
ThreadInterface,
ThreadMessage,
MessageInterface,
ThreadAssistantInfo,
} from '../../types'
import { BaseExtension, ExtensionTypeEnum } from '../extension'
/**
@ -17,10 +23,21 @@ export abstract class ConversationalExtension
return ExtensionTypeEnum.Conversational
}
abstract getThreads(): Promise<Thread[]>
abstract saveThread(thread: Thread): Promise<void>
abstract listThreads(): Promise<Thread[]>
abstract createThread(thread: Partial<Thread>): Promise<Thread>
abstract modifyThread(thread: Thread): Promise<void>
abstract deleteThread(threadId: string): Promise<void>
abstract addNewMessage(message: ThreadMessage): Promise<void>
abstract writeMessages(threadId: string, messages: ThreadMessage[]): Promise<void>
abstract getAllMessages(threadId: string): Promise<ThreadMessage[]>
abstract createMessage(message: Partial<ThreadMessage>): Promise<ThreadMessage>
abstract deleteMessage(threadId: string, messageId: string): Promise<void>
abstract listMessages(threadId: string): Promise<ThreadMessage[]>
abstract getThreadAssistant(threadId: string): Promise<ThreadAssistantInfo>
abstract createThreadAssistant(
threadId: string,
assistant: ThreadAssistantInfo
): Promise<ThreadAssistantInfo>
abstract modifyThreadAssistant(
threadId: string,
assistant: ThreadAssistantInfo
): Promise<ThreadAssistantInfo>
abstract modifyMessage(message: ThreadMessage): Promise<ThreadMessage>
}

View File

@ -2,7 +2,6 @@ import { events } from '../../events'
import { BaseExtension } from '../../extension'
import { MessageRequest, Model, ModelEvent } from '../../../types'
import { EngineManager } from './EngineManager'
import { ModelManager } from '../../models/manager'
/**
* Base AIEngine

View File

@ -6,7 +6,6 @@ import {
mkdirSync,
appendFileSync,
createWriteStream,
rmdirSync,
} from 'fs'
import { JanApiRouteConfiguration, RouteConfiguration } from './configuration'
import { join } from 'path'
@ -126,7 +125,7 @@ export const createThread = async (thread: any) => {
}
}
const threadId = generateThreadId(thread.assistants[0].assistant_id)
const threadId = generateThreadId(thread.assistants[0]?.assistant_id)
try {
const updatedThread = {
...thread,
@ -280,7 +279,7 @@ export const models = async (request: any, reply: any) => {
'Content-Type': 'application/json',
}
const response = await fetch(`${CORTEX_API_URL}/models${request.url.split('/models')[1] ?? ""}`, {
const response = await fetch(`${CORTEX_API_URL}/models${request.url.split('/models')[1] ?? ''}`, {
method: request.method,
headers: headers,
body: JSON.stringify(request.body),

View File

@ -36,3 +36,10 @@ export type Assistant = {
/** Represents the metadata of the object. */
metadata?: Record<string, unknown>
}
export interface CodeInterpreterTool {
/**
* The type of tool being defined: `code_interpreter`
*/
type: 'code_interpreter'
}

View File

@ -1,3 +1,4 @@
import { CodeInterpreterTool } from '../assistant'
import { ChatCompletionMessage, ChatCompletionRole } from '../inference'
import { ModelInfo } from '../model'
import { Thread } from '../thread'
@ -15,6 +16,10 @@ export type ThreadMessage = {
thread_id: string
/** The assistant id of this thread. **/
assistant_id?: string
/**
* A list of files attached to the message, and the tools they were added to.
*/
attachments?: Array<Attachment> | null
/** The role of the author of this message. **/
role: ChatCompletionRole
/** The content of this message. **/
@ -52,6 +57,11 @@ export type MessageRequest = {
*/
assistantId?: string
/**
* A list of files attached to the message, and the tools they were added to.
*/
attachments: Array<Attachment> | null
/** Messages for constructing a chat completion request **/
messages?: ChatCompletionMessage[]
@ -97,8 +107,7 @@ export enum ErrorCode {
*/
export enum ContentType {
Text = 'text',
Image = 'image',
Pdf = 'pdf',
Image = 'image_url',
}
/**
@ -108,8 +117,15 @@ export enum ContentType {
export type ContentValue = {
value: string
annotations: string[]
name?: string
size?: number
}
/**
* The `ImageContentValue` type defines the shape of a content value object of image type
* @data_transfer_object
*/
export type ImageContentValue = {
detail?: string
url?: string
}
/**
@ -118,5 +134,37 @@ export type ContentValue = {
*/
export type ThreadContent = {
type: ContentType
text: ContentValue
text?: ContentValue
image_url?: ImageContentValue
}
export interface Attachment {
/**
* The ID of the file to attach to the message.
*/
file_id?: string
/**
* The tools to add this file to.
*/
tools?: Array<CodeInterpreterTool | Attachment.AssistantToolsFileSearchTypeOnly>
}
export namespace Attachment {
export interface AssistantToolsFileSearchTypeOnly {
/**
* The type of tool being defined: `file_search`
*/
type: 'file_search'
}
}
/**
* On an incomplete message, details about why the message is incomplete.
*/
export interface IncompleteDetails {
/**
* The reason the message is incomplete.
*/
reason: 'content_filter' | 'max_tokens' | 'run_cancelled' | 'run_expired' | 'run_failed'
}

View File

@ -11,20 +11,20 @@ export interface MessageInterface {
* @param {ThreadMessage} message - The message to be added.
* @returns {Promise<void>} A promise that resolves when the message has been added.
*/
addNewMessage(message: ThreadMessage): Promise<void>
/**
* Writes an array of messages to a specific thread.
* @param {string} threadId - The ID of the thread to write the messages to.
* @param {ThreadMessage[]} messages - The array of messages to be written.
* @returns {Promise<void>} A promise that resolves when the messages have been written.
*/
writeMessages(threadId: string, messages: ThreadMessage[]): Promise<void>
createMessage(message: ThreadMessage): Promise<ThreadMessage>
/**
* Retrieves all messages from a specific thread.
* @param {string} threadId - The ID of the thread to retrieve the messages from.
* @returns {Promise<ThreadMessage[]>} A promise that resolves to an array of messages from the thread.
*/
getAllMessages(threadId: string): Promise<ThreadMessage[]>
listMessages(threadId: string): Promise<ThreadMessage[]>
/**
* Deletes a specific message from a thread.
* @param {string} threadId - The ID of the thread from which the message will be deleted.
* @param {string} messageId - The ID of the message to be deleted.
* @returns {Promise<void>} A promise that resolves when the message has been successfully deleted.
*/
deleteMessage(threadId: string, messageId: string): Promise<void>
}

View File

@ -11,15 +11,23 @@ export interface ThreadInterface {
* @abstract
* @returns {Promise<Thread[]>} A promise that resolves to an array of threads.
*/
getThreads(): Promise<Thread[]>
listThreads(): Promise<Thread[]>
/**
* Saves a thread.
* Create a thread.
* @abstract
* @param {Thread} thread - The thread to save.
* @returns {Promise<void>} A promise that resolves when the thread is saved.
*/
saveThread(thread: Thread): Promise<void>
createThread(thread: Thread): Promise<Thread>
/**
* modify a thread.
* @abstract
* @param {Thread} thread - The thread to save.
* @returns {Promise<void>} A promise that resolves when the thread is saved.
*/
modifyThread(thread: Thread): Promise<void>
/**
* Deletes a thread.

View File

@ -108,7 +108,7 @@ export const test = base.extend<
})
test.beforeAll(async () => {
await rmSync(path.join(__dirname, '../../test-data'), {
rmSync(path.join(__dirname, '../../test-data'), {
recursive: true,
force: true,
})
@ -122,6 +122,5 @@ test.beforeAll(async () => {
})
test.afterAll(async () => {
// temporally disabling this due to the config for parallel testing WIP
// teardownElectron()
})

View File

@ -2,11 +2,8 @@ import { expect } from '@playwright/test'
import { page, test, TIMEOUT } from '../config/fixtures'
test('renders left navigation panel', async () => {
const settingsBtn = await page
.getByTestId('Thread')
.first()
.isEnabled({ timeout: TIMEOUT })
expect([settingsBtn].filter((e) => !e).length).toBe(0)
const threadBtn = page.getByTestId('Thread').first()
await expect(threadBtn).toBeVisible({ timeout: TIMEOUT })
// Chat section should be there
await page.getByTestId('Local API Server').first().click({
timeout: TIMEOUT,

View File

@ -141,7 +141,7 @@ export default class JanAssistantExtension extends AssistantExtension {
top_k: 2,
chunk_size: 1024,
chunk_overlap: 64,
retrieval_template: `Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
retrieval_template: `Use the following pieces of context to answer the question at the end.
----------------
CONTEXT: {CONTEXT}
----------------

View File

@ -9,13 +9,14 @@ export function toolRetrievalUpdateTextSplitter(
retrieval.updateTextSplitter(chunkSize, chunkOverlap)
}
export async function toolRetrievalIngestNewDocument(
thread: string,
file: string,
model: string,
engine: string,
useTimeWeighted: boolean
) {
const filePath = path.join(getJanDataFolderPath(), normalizeFilePath(file))
const threadPath = path.dirname(filePath.replace('files', ''))
const threadPath = path.join(getJanDataFolderPath(), 'threads', thread)
const filePath = path.join(getJanDataFolderPath(), 'files', file)
retrieval.updateEmbeddingEngine(model, engine)
return retrieval
.ingestAgentKnowledge(filePath, `${threadPath}/memory`, useTimeWeighted)

View File

@ -35,6 +35,7 @@ export class RetrievalTool extends InferenceTool {
await executeOnMain(
NODE,
'toolRetrievalIngestNewDocument',
data.thread?.id,
docFile,
data.model?.id,
data.model?.engine,

View File

@ -18,12 +18,14 @@
"devDependencies": {
"cpx": "^1.5.0",
"rimraf": "^3.0.2",
"ts-loader": "^9.5.0",
"webpack": "^5.88.2",
"webpack-cli": "^5.1.4",
"ts-loader": "^9.5.0"
"webpack-cli": "^5.1.4"
},
"dependencies": {
"@janhq/core": "file:../../core"
"@janhq/core": "file:../../core",
"ky": "^1.7.2",
"p-queue": "^8.0.1"
},
"engines": {
"node": ">=18.0.0"

View File

@ -0,0 +1,14 @@
export {}
declare global {
declare const API_URL: string
declare const SOCKET_URL: string
interface Core {
api: APIFunctions
events: EventEmitter
}
interface Window {
core?: Core | undefined
electronAPI?: any | undefined
}
}

View File

@ -1,408 +0,0 @@
/**
* @jest-environment jsdom
*/
jest.mock('@janhq/core', () => ({
...jest.requireActual('@janhq/core/node'),
fs: {
existsSync: jest.fn(),
mkdir: jest.fn(),
writeFileSync: jest.fn(),
readdirSync: jest.fn(),
readFileSync: jest.fn(),
appendFileSync: jest.fn(),
rm: jest.fn(),
writeBlob: jest.fn(),
joinPath: jest.fn(),
fileStat: jest.fn(),
},
joinPath: jest.fn(),
ConversationalExtension: jest.fn(),
}))
import { fs } from '@janhq/core'
import JSONConversationalExtension from '.'
describe('JSONConversationalExtension Tests', () => {
let extension: JSONConversationalExtension
beforeEach(() => {
// @ts-ignore
extension = new JSONConversationalExtension()
})
it('should create thread folder on load if it does not exist', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
await extension.onLoad()
expect(mkdirSpy).toHaveBeenCalledWith('file://threads')
})
it('should log message on unload', () => {
const consoleSpy = jest.spyOn(console, 'debug').mockImplementation()
extension.onUnload()
expect(consoleSpy).toHaveBeenCalledWith(
'JSONConversationalExtension unloaded'
)
})
it('should return sorted threads', async () => {
jest
.spyOn(extension, 'getValidThreadDirs')
.mockResolvedValue(['dir1', 'dir2'])
jest
.spyOn(extension, 'readThread')
.mockResolvedValueOnce({ updated: '2023-01-01' })
.mockResolvedValueOnce({ updated: '2023-01-02' })
const threads = await extension.getThreads()
expect(threads).toEqual([
{ updated: '2023-01-02' },
{ updated: '2023-01-01' },
])
})
it('should ignore broken threads', async () => {
jest
.spyOn(extension, 'getValidThreadDirs')
.mockResolvedValue(['dir1', 'dir2'])
jest
.spyOn(extension, 'readThread')
.mockResolvedValueOnce(JSON.stringify({ updated: '2023-01-01' }))
.mockResolvedValueOnce('this_is_an_invalid_json_content')
const threads = await extension.getThreads()
expect(threads).toEqual([{ updated: '2023-01-01' }])
})
it('should save thread', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const writeFileSyncSpy = jest
.spyOn(fs, 'writeFileSync')
.mockResolvedValue({})
const thread = { id: '1', updated: '2023-01-01' } as any
await extension.saveThread(thread)
expect(mkdirSpy).toHaveBeenCalled()
expect(writeFileSyncSpy).toHaveBeenCalled()
})
it('should delete thread', async () => {
const rmSpy = jest.spyOn(fs, 'rm').mockResolvedValue({})
await extension.deleteThread('1')
expect(rmSpy).toHaveBeenCalled()
})
it('should add new message', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const appendFileSyncSpy = jest
.spyOn(fs, 'appendFileSync')
.mockResolvedValue({})
const message = {
thread_id: '1',
content: [{ type: 'text', text: { annotations: [] } }],
} as any
await extension.addNewMessage(message)
expect(mkdirSpy).toHaveBeenCalled()
expect(appendFileSyncSpy).toHaveBeenCalled()
})
it('should store image', async () => {
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
await extension.storeImage(
'',
'path/to/image.png'
)
expect(writeBlobSpy).toHaveBeenCalled()
})
it('should store file', async () => {
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
await extension.storeFile(
'data:application/pdf;base64,abcd',
'path/to/file.pdf'
)
expect(writeBlobSpy).toHaveBeenCalled()
})
it('should write messages', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const writeFileSyncSpy = jest
.spyOn(fs, 'writeFileSync')
.mockResolvedValue({})
const messages = [{ id: '1', thread_id: '1', content: [] }] as any
await extension.writeMessages('1', messages)
expect(mkdirSpy).toHaveBeenCalled()
expect(writeFileSyncSpy).toHaveBeenCalled()
})
it('should get all messages on string response', async () => {
jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl'])
jest.spyOn(fs, 'readFileSync').mockResolvedValue('{"id":"1"}\n{"id":"2"}\n')
const messages = await extension.getAllMessages('1')
expect(messages).toEqual([{ id: '1' }, { id: '2' }])
})
it('should get all messages on object response', async () => {
jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl'])
jest.spyOn(fs, 'readFileSync').mockResolvedValue({ id: 1 })
const messages = await extension.getAllMessages('1')
expect(messages).toEqual([{ id: 1 }])
})
it('get all messages return empty on error', async () => {
jest.spyOn(fs, 'readdirSync').mockRejectedValue(['messages.jsonl'])
const messages = await extension.getAllMessages('1')
expect(messages).toEqual([])
})
it('return empty messages on no messages file', async () => {
jest.spyOn(fs, 'readdirSync').mockResolvedValue([])
const messages = await extension.getAllMessages('1')
expect(messages).toEqual([])
})
it('should ignore error message', async () => {
jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl'])
jest
.spyOn(fs, 'readFileSync')
.mockResolvedValue('{"id":"1"}\nyolo\n{"id":"2"}\n')
const messages = await extension.getAllMessages('1')
expect(messages).toEqual([{ id: '1' }, { id: '2' }])
})
it('should create thread folder on load if it does not exist', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
await extension.onLoad()
expect(mkdirSpy).toHaveBeenCalledWith('file://threads')
})
it('should log message on unload', () => {
const consoleSpy = jest.spyOn(console, 'debug').mockImplementation()
extension.onUnload()
expect(consoleSpy).toHaveBeenCalledWith(
'JSONConversationalExtension unloaded'
)
})
it('should return sorted threads', async () => {
jest
.spyOn(extension, 'getValidThreadDirs')
.mockResolvedValue(['dir1', 'dir2'])
jest
.spyOn(extension, 'readThread')
.mockResolvedValueOnce({ updated: '2023-01-01' })
.mockResolvedValueOnce({ updated: '2023-01-02' })
const threads = await extension.getThreads()
expect(threads).toEqual([
{ updated: '2023-01-02' },
{ updated: '2023-01-01' },
])
})
it('should ignore broken threads', async () => {
jest
.spyOn(extension, 'getValidThreadDirs')
.mockResolvedValue(['dir1', 'dir2'])
jest
.spyOn(extension, 'readThread')
.mockResolvedValueOnce(JSON.stringify({ updated: '2023-01-01' }))
.mockResolvedValueOnce('this_is_an_invalid_json_content')
const threads = await extension.getThreads()
expect(threads).toEqual([{ updated: '2023-01-01' }])
})
it('should save thread', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const writeFileSyncSpy = jest
.spyOn(fs, 'writeFileSync')
.mockResolvedValue({})
const thread = { id: '1', updated: '2023-01-01' } as any
await extension.saveThread(thread)
expect(mkdirSpy).toHaveBeenCalled()
expect(writeFileSyncSpy).toHaveBeenCalled()
})
it('should delete thread', async () => {
const rmSpy = jest.spyOn(fs, 'rm').mockResolvedValue({})
await extension.deleteThread('1')
expect(rmSpy).toHaveBeenCalled()
})
it('should add new message', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const appendFileSyncSpy = jest
.spyOn(fs, 'appendFileSync')
.mockResolvedValue({})
const message = {
thread_id: '1',
content: [{ type: 'text', text: { annotations: [] } }],
} as any
await extension.addNewMessage(message)
expect(mkdirSpy).toHaveBeenCalled()
expect(appendFileSyncSpy).toHaveBeenCalled()
})
it('should add new image message', async () => {
jest
.spyOn(fs, 'existsSync')
// @ts-ignore
.mockResolvedValueOnce(false)
// @ts-ignore
.mockResolvedValueOnce(false)
// @ts-ignore
.mockResolvedValueOnce(true)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const appendFileSyncSpy = jest
.spyOn(fs, 'appendFileSync')
.mockResolvedValue({})
jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
const message = {
thread_id: '1',
content: [
{ type: 'image', text: { annotations: ['data:image;base64,hehe'] } },
],
} as any
await extension.addNewMessage(message)
expect(mkdirSpy).toHaveBeenCalled()
expect(appendFileSyncSpy).toHaveBeenCalled()
})
it('should add new pdf message', async () => {
jest
.spyOn(fs, 'existsSync')
// @ts-ignore
.mockResolvedValueOnce(false)
// @ts-ignore
.mockResolvedValueOnce(false)
// @ts-ignore
.mockResolvedValueOnce(true)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const appendFileSyncSpy = jest
.spyOn(fs, 'appendFileSync')
.mockResolvedValue({})
jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
const message = {
thread_id: '1',
content: [
{ type: 'pdf', text: { annotations: ['data:pdf;base64,hehe'] } },
],
} as any
await extension.addNewMessage(message)
expect(mkdirSpy).toHaveBeenCalled()
expect(appendFileSyncSpy).toHaveBeenCalled()
})
it('should store image', async () => {
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
await extension.storeImage(
'',
'path/to/image.png'
)
expect(writeBlobSpy).toHaveBeenCalled()
})
it('should store file', async () => {
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
await extension.storeFile(
'data:application/pdf;base64,abcd',
'path/to/file.pdf'
)
expect(writeBlobSpy).toHaveBeenCalled()
})
})
describe('test readThread', () => {
let extension: JSONConversationalExtension
beforeEach(() => {
// @ts-ignore
extension = new JSONConversationalExtension()
})
it('should read thread', async () => {
jest
.spyOn(fs, 'readFileSync')
.mockResolvedValue(JSON.stringify({ id: '1' }))
const thread = await extension.readThread('1')
expect(thread).toEqual(`{"id":"1"}`)
})
it('getValidThreadDirs should return valid thread directories', async () => {
jest
.spyOn(fs, 'readdirSync')
.mockResolvedValueOnce(['1', '2', '3'])
.mockResolvedValueOnce(['thread.json'])
.mockResolvedValueOnce(['thread.json'])
.mockResolvedValueOnce([])
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(true)
jest.spyOn(fs, 'fileStat').mockResolvedValue({
isDirectory: true,
} as any)
const validThreadDirs = await extension.getValidThreadDirs()
expect(validThreadDirs).toEqual(['1', '2'])
})
})

View File

@ -1,90 +1,71 @@
import {
fs,
joinPath,
ConversationalExtension,
Thread,
ThreadAssistantInfo,
ThreadMessage,
} from '@janhq/core'
import { safelyParseJSON } from './jsonUtil'
import ky from 'ky'
import PQueue from 'p-queue'
type ThreadList = {
data: Thread[]
}
type MessageList = {
data: ThreadMessage[]
}
/**
* JSONConversationalExtension is a ConversationalExtension implementation that provides
* functionality for managing threads.
*/
export default class JSONConversationalExtension extends ConversationalExtension {
private static readonly _threadFolder = 'file://threads'
private static readonly _threadInfoFileName = 'thread.json'
private static readonly _threadMessagesFileName = 'messages.jsonl'
queue = new PQueue({ concurrency: 1 })
/**
* Called when the extension is loaded.
*/
async onLoad() {
if (!(await fs.existsSync(JSONConversationalExtension._threadFolder))) {
await fs.mkdir(JSONConversationalExtension._threadFolder)
}
this.queue.add(() => this.healthz())
}
/**
* Called when the extension is unloaded.
*/
onUnload() {
console.debug('JSONConversationalExtension unloaded')
}
onUnload() {}
/**
* Returns a Promise that resolves to an array of Conversation objects.
*/
async getThreads(): Promise<Thread[]> {
try {
const threadDirs = await this.getValidThreadDirs()
const promises = threadDirs.map((dirName) => this.readThread(dirName))
const promiseResults = await Promise.allSettled(promises)
const convos = promiseResults
.map((result) => {
if (result.status === 'fulfilled') {
return typeof result.value === 'object'
? result.value
: safelyParseJSON(result.value)
}
return undefined
})
.filter((convo) => !!convo)
convos.sort(
(a, b) => new Date(b.updated).getTime() - new Date(a.updated).getTime()
)
return convos
} catch (error) {
console.error(error)
return []
}
async listThreads(): Promise<Thread[]> {
return this.queue.add(() =>
ky
.get(`${API_URL}/v1/threads`)
.json<ThreadList>()
.then((e) => e.data)
) as Promise<Thread[]>
}
/**
* Saves a Thread object to a json file.
* @param thread The Thread object to save.
*/
async saveThread(thread: Thread): Promise<void> {
try {
const threadDirPath = await joinPath([
JSONConversationalExtension._threadFolder,
thread.id,
])
const threadJsonPath = await joinPath([
threadDirPath,
JSONConversationalExtension._threadInfoFileName,
])
if (!(await fs.existsSync(threadDirPath))) {
await fs.mkdir(threadDirPath)
async createThread(thread: Thread): Promise<Thread> {
return this.queue.add(() =>
ky.post(`${API_URL}/v1/threads`, { json: thread }).json<Thread>()
) as Promise<Thread>
}
await fs.writeFileSync(threadJsonPath, JSON.stringify(thread, null, 2))
} catch (err) {
console.error(err)
Promise.reject(err)
}
/**
* Saves a Thread object to a json file.
* @param thread The Thread object to save.
*/
async modifyThread(thread: Thread): Promise<void> {
return this.queue
.add(() =>
ky.post(`${API_URL}/v1/threads/${thread.id}`, { json: thread })
)
.then()
}
/**
@ -92,189 +73,126 @@ export default class JSONConversationalExtension extends ConversationalExtension
* @param threadId The ID of the thread to delete.
*/
async deleteThread(threadId: string): Promise<void> {
const path = await joinPath([
JSONConversationalExtension._threadFolder,
`${threadId}`,
])
try {
await fs.rm(path)
} catch (err) {
console.error(err)
}
}
async addNewMessage(message: ThreadMessage): Promise<void> {
try {
const threadDirPath = await joinPath([
JSONConversationalExtension._threadFolder,
message.thread_id,
])
const threadMessagePath = await joinPath([
threadDirPath,
JSONConversationalExtension._threadMessagesFileName,
])
if (!(await fs.existsSync(threadDirPath))) await fs.mkdir(threadDirPath)
if (message.content[0]?.type === 'image') {
const filesPath = await joinPath([threadDirPath, 'files'])
if (!(await fs.existsSync(filesPath))) await fs.mkdir(filesPath)
const imagePath = await joinPath([filesPath, `${message.id}.png`])
const base64 = message.content[0].text.annotations[0]
await this.storeImage(base64, imagePath)
if ((await fs.existsSync(imagePath)) && message.content?.length) {
// Use file path instead of blob
message.content[0].text.annotations[0] = `threads/${message.thread_id}/files/${message.id}.png`
}
}
if (message.content[0]?.type === 'pdf') {
const filesPath = await joinPath([threadDirPath, 'files'])
if (!(await fs.existsSync(filesPath))) await fs.mkdir(filesPath)
const filePath = await joinPath([filesPath, `${message.id}.pdf`])
const blob = message.content[0].text.annotations[0]
await this.storeFile(blob, filePath)
if ((await fs.existsSync(filePath)) && message.content?.length) {
// Use file path instead of blob
message.content[0].text.annotations[0] = `threads/${message.thread_id}/files/${message.id}.pdf`
}
}
await fs.appendFileSync(threadMessagePath, JSON.stringify(message) + '\n')
Promise.resolve()
} catch (err) {
Promise.reject(err)
}
}
async storeImage(base64: string, filePath: string): Promise<void> {
const base64Data = base64.replace(/^data:image\/\w+;base64,/, '')
try {
await fs.writeBlob(filePath, base64Data)
} catch (err) {
console.error(err)
}
}
async storeFile(base64: string, filePath: string): Promise<void> {
const base64Data = base64.replace(/^data:application\/pdf;base64,/, '')
try {
await fs.writeBlob(filePath, base64Data)
} catch (err) {
console.error(err)
}
}
async writeMessages(
threadId: string,
messages: ThreadMessage[]
): Promise<void> {
try {
const threadDirPath = await joinPath([
JSONConversationalExtension._threadFolder,
threadId,
])
const threadMessagePath = await joinPath([
threadDirPath,
JSONConversationalExtension._threadMessagesFileName,
])
if (!(await fs.existsSync(threadDirPath))) await fs.mkdir(threadDirPath)
await fs.writeFileSync(
threadMessagePath,
messages.map((msg) => JSON.stringify(msg)).join('\n') +
(messages.length ? '\n' : '')
)
Promise.resolve()
} catch (err) {
Promise.reject(err)
}
return this.queue
.add(() => ky.delete(`${API_URL}/v1/threads/${threadId}`))
.then()
}
/**
* A promise builder for reading a thread from a file.
* @param threadDirName the thread dir we are reading from.
* @returns data of the thread
* Adds a new message to a specified thread.
* @param message The ThreadMessage object to be added.
* @returns A Promise that resolves when the message has been added.
*/
async readThread(threadDirName: string): Promise<any> {
return fs.readFileSync(
await joinPath([
JSONConversationalExtension._threadFolder,
threadDirName,
JSONConversationalExtension._threadInfoFileName,
]),
'utf-8'
)
}
/**
* Returns a Promise that resolves to an array of thread directories.
* @private
*/
async getValidThreadDirs(): Promise<string[]> {
const fileInsideThread: string[] = await fs.readdirSync(
JSONConversationalExtension._threadFolder
)
const threadDirs: string[] = []
for (let i = 0; i < fileInsideThread.length; i++) {
const path = await joinPath([
JSONConversationalExtension._threadFolder,
fileInsideThread[i],
])
if (!(await fs.fileStat(path))?.isDirectory) continue
const isHavingThreadInfo = (await fs.readdirSync(path)).includes(
JSONConversationalExtension._threadInfoFileName
)
if (!isHavingThreadInfo) {
console.debug(`Ignore ${path} because it does not have thread info`)
continue
}
threadDirs.push(fileInsideThread[i])
}
return threadDirs
}
async getAllMessages(threadId: string): Promise<ThreadMessage[]> {
try {
const threadDirPath = await joinPath([
JSONConversationalExtension._threadFolder,
threadId,
])
const files: string[] = await fs.readdirSync(threadDirPath)
if (
!files.includes(JSONConversationalExtension._threadMessagesFileName)
) {
console.debug(`${threadDirPath} not contains message file`)
return []
}
const messageFilePath = await joinPath([
threadDirPath,
JSONConversationalExtension._threadMessagesFileName,
])
let readResult = await fs.readFileSync(messageFilePath, 'utf-8')
if (typeof readResult === 'object') {
readResult = JSON.stringify(readResult)
}
const result = readResult.split('\n').filter((line) => line !== '')
const messages: ThreadMessage[] = []
result.forEach((line: string) => {
const message = safelyParseJSON(line)
if (message) messages.push(safelyParseJSON(line))
async createMessage(message: ThreadMessage): Promise<ThreadMessage> {
return this.queue.add(() =>
ky
.post(`${API_URL}/v1/threads/${message.thread_id}/messages`, {
json: message,
})
return messages
} catch (err) {
console.error(err)
return []
.json<ThreadMessage>()
) as Promise<ThreadMessage>
}
/**
* Modifies a message in a thread.
* @param message
* @returns
*/
async modifyMessage(message: ThreadMessage): Promise<ThreadMessage> {
return this.queue.add(() =>
ky
.post(
`${API_URL}/v1/threads/${message.thread_id}/messages/${message.id}`,
{
json: message,
}
)
.json<ThreadMessage>()
) as Promise<ThreadMessage>
}
/**
* Deletes a specific message from a thread.
* @param threadId The ID of the thread containing the message.
* @param messageId The ID of the message to be deleted.
* @returns A Promise that resolves when the message has been successfully deleted.
*/
async deleteMessage(threadId: string, messageId: string): Promise<void> {
return this.queue
.add(() =>
ky.delete(`${API_URL}/v1/threads/${threadId}/messages/${messageId}`)
)
.then()
}
/**
* Retrieves all messages for a specified thread.
* @param threadId The ID of the thread to get messages from.
* @returns A Promise that resolves to an array of ThreadMessage objects.
*/
async listMessages(threadId: string): Promise<ThreadMessage[]> {
return this.queue.add(() =>
ky
.get(`${API_URL}/v1/threads/${threadId}/messages?order=asc`)
.json<MessageList>()
.then((e) => e.data)
) as Promise<ThreadMessage[]>
}
/**
* Retrieves the assistant information for a specified thread.
* @param threadId The ID of the thread for which to retrieve assistant information.
* @returns A Promise that resolves to a ThreadAssistantInfo object containing
* the details of the assistant associated with the specified thread.
*/
async getThreadAssistant(threadId: string): Promise<ThreadAssistantInfo> {
return this.queue.add(() =>
ky.get(`${API_URL}/v1/assistants/${threadId}`).json<ThreadAssistantInfo>()
) as Promise<ThreadAssistantInfo>
}
/**
* Creates a new assistant for the specified thread.
* @param threadId The ID of the thread for which the assistant is being created.
* @param assistant The information about the assistant to be created.
* @returns A Promise that resolves to the newly created ThreadAssistantInfo object.
*/
async createThreadAssistant(
threadId: string,
assistant: ThreadAssistantInfo
): Promise<ThreadAssistantInfo> {
return this.queue.add(() =>
ky
.post(`${API_URL}/v1/assistants/${threadId}`, { json: assistant })
.json<ThreadAssistantInfo>()
) as Promise<ThreadAssistantInfo>
}
/**
* Modifies an existing assistant for the specified thread.
* @param threadId The ID of the thread for which the assistant is being modified.
* @param assistant The updated information for the assistant.
* @returns A Promise that resolves to the updated ThreadAssistantInfo object.
*/
async modifyThreadAssistant(
threadId: string,
assistant: ThreadAssistantInfo
): Promise<ThreadAssistantInfo> {
return this.queue.add(() =>
ky
.patch(`${API_URL}/v1/assistants/${threadId}`, { json: assistant })
.json<ThreadAssistantInfo>()
) as Promise<ThreadAssistantInfo>
}
/**
* Do health check on cortex.cpp
* @returns
*/
healthz(): Promise<void> {
return ky
.get(`${API_URL}/healthz`, {
retry: { limit: 20, delay: () => 500, methods: ['get'] },
})
.then(() => {})
}
}

View File

@ -1,14 +0,0 @@
// Note about performance
// The v8 JavaScript engine used by Node.js cannot optimise functions which contain a try/catch block.
// v8 4.5 and above can optimise try/catch
export function safelyParseJSON(json) {
// This function cannot be optimised, it's best to
// keep it small!
var parsed
try {
parsed = JSON.parse(json)
} catch (e) {
return undefined
}
return parsed // Could be undefined!
}

View File

@ -17,7 +17,12 @@ module.exports = {
filename: 'index.js', // Adjust the output file name as needed
library: { type: 'module' }, // Specify ESM output format
},
plugins: [new webpack.DefinePlugin({})],
plugins: [
new webpack.DefinePlugin({
API_URL: JSON.stringify('http://127.0.0.1:39291'),
SOCKET_URL: JSON.stringify('ws://127.0.0.1:39291'),
}),
],
resolve: {
extensions: ['.ts', '.js'],
},

View File

@ -1 +1 @@
1.0.4
1.0.5-rc1

View File

@ -2,7 +2,7 @@
set BIN_PATH=./bin
set SHARED_PATH=./../../electron/shared
set /p CORTEX_VERSION=<./bin/version.txt
set ENGINE_VERSION=0.1.40
set ENGINE_VERSION=0.1.42
@REM Download cortex.llamacpp binaries
set DOWNLOAD_URL=https://github.com/janhq/cortex.llamacpp/releases/download/v%ENGINE_VERSION%/cortex.llamacpp-%ENGINE_VERSION%-windows-amd64

View File

@ -2,7 +2,7 @@
# Read CORTEX_VERSION
CORTEX_VERSION=$(cat ./bin/version.txt)
ENGINE_VERSION=0.1.40
ENGINE_VERSION=0.1.42
CORTEX_RELEASE_URL="https://github.com/janhq/cortex.cpp/releases/download"
ENGINE_DOWNLOAD_URL="https://github.com/janhq/cortex.llamacpp/releases/download/v${ENGINE_VERSION}/cortex.llamacpp-${ENGINE_VERSION}"
CUDA_DOWNLOAD_URL="https://github.com/janhq/cortex.llamacpp/releases/download/v${ENGINE_VERSION}"

View File

@ -120,7 +120,7 @@ export default [
SETTINGS: JSON.stringify(defaultSettingJson),
CORTEX_API_URL: JSON.stringify('http://127.0.0.1:39291'),
CORTEX_SOCKET_URL: JSON.stringify('ws://127.0.0.1:39291'),
CORTEX_ENGINE_VERSION: JSON.stringify('v0.1.40'),
CORTEX_ENGINE_VERSION: JSON.stringify('v0.1.42'),
}),
// Allow json resolution
json(),

View File

@ -18,14 +18,14 @@ import { isLocalEngine } from '@/utils/modelEngine'
import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
const ErrorMessage = ({ message }: { message: ThreadMessage }) => {
const setModalTroubleShooting = useSetAtom(modalTroubleShootingAtom)
const setMainState = useSetAtom(mainViewStateAtom)
const setSelectedSettingScreen = useSetAtom(selectedSettingAtom)
const activeThread = useAtomValue(activeThreadAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const defaultDesc = () => {
return (
@ -46,7 +46,7 @@ const ErrorMessage = ({ message }: { message: ThreadMessage }) => {
}
const getEngine = () => {
const engineName = activeThread?.assistants?.[0]?.model?.engine
const engineName = activeAssistant?.model?.engine
return engineName ? EngineManager.instance().get(engineName) : null
}
@ -89,7 +89,9 @@ const ErrorMessage = ({ message }: { message: ThreadMessage }) => {
</span>
) : (
<>
<AutoLink text={message.content[0].text.value} />
{message?.content[0]?.text?.value && (
<AutoLink text={message?.content[0]?.text?.value} />
)}
{defaultDesc()}
</>
)}

View File

@ -46,6 +46,7 @@ import {
import { extensionManager } from '@/extension'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom'
import {
configuredModelsAtom,
@ -75,6 +76,7 @@ const ModelDropdown = ({
const [searchText, setSearchText] = useState('')
const [open, setOpen] = useState(false)
const activeThread = useAtomValue(activeThreadAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const downloadingModels = useAtomValue(getDownloadingModelAtom)
const [toggle, setToggle] = useState<HTMLDivElement | null>(null)
const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
@ -151,17 +153,24 @@ const ModelDropdown = ({
useEffect(() => {
if (!activeThread) return
const modelId = activeThread?.assistants?.[0]?.model?.id
const modelId = activeAssistant?.model?.id
let model = downloadedModels.find((model) => model.id === modelId)
if (!model) {
model = recommendedModel
}
setSelectedModel(model)
}, [recommendedModel, activeThread, downloadedModels, setSelectedModel])
}, [
recommendedModel,
activeThread,
downloadedModels,
setSelectedModel,
activeAssistant?.model?.id,
])
const onClickModelItem = useCallback(
async (modelId: string) => {
if (!activeAssistant) return
const model = downloadedModels.find((m) => m.id === modelId)
setSelectedModel(model)
setOpen(false)
@ -172,14 +181,14 @@ const ModelDropdown = ({
...activeThread,
assistants: [
{
...activeThread.assistants[0],
...activeAssistant,
tools: [
{
type: 'retrieval',
enabled: isModelSupportRagAndTools(model as Model),
settings: {
...(activeThread.assistants[0].tools &&
activeThread.assistants[0].tools[0]?.settings),
...(activeAssistant.tools &&
activeAssistant.tools[0]?.settings),
},
},
],
@ -219,13 +228,14 @@ const ModelDropdown = ({
}
},
[
activeAssistant,
downloadedModels,
activeThread,
setSelectedModel,
activeThread,
updateThreadMetadata,
isModelSupportRagAndTools,
setThreadModelParams,
updateModelParameter,
updateThreadMetadata,
]
)

View File

@ -8,7 +8,7 @@ import { FileInfo } from '@/types/file'
export const editPromptAtom = atom<string>('')
export const currentPromptAtom = atom<string>('')
export const fileUploadAtom = atom<FileInfo[]>([])
export const fileUploadAtom = atom<FileInfo | undefined>()
export const searchAtom = atom<string>('')

View File

@ -31,6 +31,7 @@ import {
addNewMessageAtom,
updateMessageAtom,
tokenSpeedAtom,
deleteMessageAtom,
} from '@/helpers/atoms/ChatMessage.atom'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import {
@ -49,6 +50,7 @@ export default function ModelHandler() {
const addNewMessage = useSetAtom(addNewMessageAtom)
const updateMessage = useSetAtom(updateMessageAtom)
const downloadedModels = useAtomValue(downloadedModelsAtom)
const deleteMessage = useSetAtom(deleteMessageAtom)
const activeModel = useAtomValue(activeModelAtom)
const setActiveModel = useSetAtom(activeModelAtom)
const setStateModel = useSetAtom(stateModelAtom)
@ -86,7 +88,7 @@ export default function ModelHandler() {
}, [activeModelParams])
const onNewMessageResponse = useCallback(
(message: ThreadMessage) => {
async (message: ThreadMessage) => {
if (message.type === MessageRequestType.Thread) {
addNewMessage(message)
}
@ -154,12 +156,15 @@ export default function ModelHandler() {
...thread,
title: cleanedMessageContent,
metadata: thread.metadata,
metadata: {
...thread.metadata,
title: cleanedMessageContent,
},
}
extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.saveThread({
?.modifyThread({
...updatedThread,
})
.then(() => {
@ -233,7 +238,9 @@ export default function ModelHandler() {
const thread = threadsRef.current?.find((e) => e.id == message.thread_id)
if (!thread) return
const messageContent = message.content[0]?.text?.value
const metadata = {
...thread.metadata,
...(messageContent && { lastMessage: messageContent }),
@ -246,15 +253,19 @@ export default function ModelHandler() {
extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.saveThread({
?.modifyThread({
...thread,
metadata,
})
// If this is not the summary of the Thread, don't need to add it to the Thread
extensionManager
;(async () => {
const updatedMessage = await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.addNewMessage(message)
?.createMessage(message)
if (updatedMessage) {
deleteMessage(message.id)
addNewMessage(updatedMessage)
}
})()
// Attempt to generate the title of the Thread when needed
generateThreadTitle(message, thread)
@ -279,7 +290,9 @@ export default function ModelHandler() {
const generateThreadTitle = (message: ThreadMessage, thread: Thread) => {
// If this is the first ever prompt in the thread
if (thread.title?.trim() !== defaultThreadTitle) {
if (
(thread.title ?? thread.metadata?.title)?.trim() !== defaultThreadTitle
) {
return
}
@ -292,11 +305,14 @@ export default function ModelHandler() {
const updatedThread: Thread = {
...thread,
title: (thread.metadata?.lastMessage as string) || defaultThreadTitle,
metadata: thread.metadata,
metadata: {
...thread.metadata,
title: (thread.metadata?.lastMessage as string) || defaultThreadTitle,
},
}
return extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.saveThread({
?.modifyThread({
...updatedThread,
})
.then(() => {
@ -313,7 +329,7 @@ export default function ModelHandler() {
if (!threadMessages || threadMessages.length === 0) return
const summarizeFirstPrompt = `Summarize in a ${maxWordForThreadTitle}-word Title. Give the title only. "${threadMessages[0].content[0].text.value}"`
const summarizeFirstPrompt = `Summarize in a ${maxWordForThreadTitle}-word Title. Give the title only. "${threadMessages[0]?.content[0]?.text?.value}"`
// Prompt: Given this query from user {query}, return to me the summary in 10 words as the title
const msgId = ulid()
@ -330,6 +346,7 @@ export default function ModelHandler() {
id: msgId,
threadId: message.thread_id,
type: MessageRequestType.Summary,
attachments: [],
messages,
model: {
...activeModelRef.current,

View File

@ -1,4 +1,12 @@
import { Assistant } from '@janhq/core'
import { Assistant, ThreadAssistantInfo } from '@janhq/core'
import { atom } from 'jotai'
import { atomWithStorage } from 'jotai/utils'
export const assistantsAtom = atom<Assistant[]>([])
/**
* Get the current active assistant
*/
export const activeAssistantAtom = atomWithStorage<
ThreadAssistantInfo | undefined
>('activeAssistant', undefined, undefined, { getOnInit: true })

View File

@ -6,6 +6,8 @@ import {
} from '@janhq/core'
import { atom } from 'jotai'
import { atomWithStorage } from 'jotai/utils'
import {
getActiveThreadIdAtom,
updateThreadStateLastMessageAtom,
@ -13,15 +15,23 @@ import {
import { TokenSpeed } from '@/types/token'
const CHAT_MESSAGE_NAME = 'chatMessages'
/**
* Stores all chat messages for all threads
*/
export const chatMessages = atom<Record<string, ThreadMessage[]>>({})
export const chatMessages = atomWithStorage<Record<string, ThreadMessage[]>>(
CHAT_MESSAGE_NAME,
{},
undefined,
{ getOnInit: true }
)
/**
* Stores the status of the messages load for each thread
*/
export const readyThreadsMessagesAtom = atom<Record<string, boolean>>({})
export const readyThreadsMessagesAtom = atomWithStorage<
Record<string, boolean>
>('currentThreadMessages', {}, undefined, { getOnInit: true })
/**
* Store the token speed for current message
@ -34,6 +44,7 @@ export const getCurrentChatMessagesAtom = atom<ThreadMessage[]>((get) => {
const activeThreadId = get(getActiveThreadIdAtom)
if (!activeThreadId) return []
const messages = get(chatMessages)[activeThreadId]
if (!Array.isArray(messages)) return []
return messages ?? []
})

View File

@ -58,9 +58,11 @@ describe('Model.atom.ts', () => {
setAtom.current({ id: '1' } as any)
})
expect(getAtom.current).toEqual([{ id: '1' }])
act(() => {
reset.current([])
})
})
})
describe('removeDownloadingModelAtom', () => {
it('should remove downloading model', async () => {
@ -83,9 +85,11 @@ describe('Model.atom.ts', () => {
removeAtom.current('1')
})
expect(getAtom.current).toEqual([])
act(() => {
reset.current([])
})
})
})
describe('removeDownloadedModelAtom', () => {
it('should remove downloaded model', async () => {
@ -113,9 +117,11 @@ describe('Model.atom.ts', () => {
removeAtom.current('1')
})
expect(getAtom.current).toEqual([])
act(() => {
reset.current([])
})
})
})
describe('importingModelAtom', () => {
afterEach(() => {

View File

@ -8,6 +8,7 @@ import { toaster } from '@/containers/Toast'
import { LAST_USED_MODEL_ID } from './useRecommendedModel'
import { vulkanEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
@ -34,6 +35,7 @@ export function useActiveModel() {
const setLoadModelError = useSetAtom(loadModelErrorAtom)
const pendingModelLoad = useRef(false)
const isVulkanEnabled = useAtomValue(vulkanEnabledAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const downloadedModelsRef = useRef<Model[]>([])
@ -79,12 +81,12 @@ export function useActiveModel() {
}
/// Apply thread model settings
if (activeThread?.assistants[0]?.model.id === modelId) {
if (activeAssistant?.model.id === modelId) {
model = {
...model,
settings: {
...model.settings,
...activeThread.assistants[0].model.settings,
...activeAssistant?.model.settings,
},
}
}

View File

@ -67,7 +67,7 @@ describe('useCreateNewThread', () => {
} as any)
})
expect(mockSetAtom).toHaveBeenCalledTimes(6) // Check if all the necessary atoms were set
expect(mockSetAtom).toHaveBeenCalledTimes(1)
expect(extensionManager.get).toHaveBeenCalled()
})
@ -104,7 +104,7 @@ describe('useCreateNewThread', () => {
await result.current.requestCreateNewThread({
id: 'assistant1',
name: 'Assistant 1',
instructions: "Hello Jan Assistant",
instructions: 'Hello Jan Assistant',
model: {
id: 'model1',
parameters: [],
@ -113,16 +113,8 @@ describe('useCreateNewThread', () => {
} as any)
})
expect(mockSetAtom).toHaveBeenCalledTimes(6) // Check if all the necessary atoms were set
expect(mockSetAtom).toHaveBeenCalledTimes(1) // Check if all the necessary atoms were set
expect(extensionManager.get).toHaveBeenCalled()
expect(mockSetAtom).toHaveBeenNthCalledWith(
2,
expect.objectContaining({
assistants: expect.arrayContaining([
expect.objectContaining({ instructions: 'Hello Jan Assistant' }),
]),
})
)
})
it('should create a new thread with previous instructions', async () => {
@ -166,16 +158,8 @@ describe('useCreateNewThread', () => {
} as any)
})
expect(mockSetAtom).toHaveBeenCalledTimes(6) // Check if all the necessary atoms were set
expect(mockSetAtom).toHaveBeenCalledTimes(1) // Check if all the necessary atoms were set
expect(extensionManager.get).toHaveBeenCalled()
expect(mockSetAtom).toHaveBeenNthCalledWith(
2,
expect.objectContaining({
assistants: expect.arrayContaining([
expect.objectContaining({ instructions: 'Hello Jan' }),
]),
})
)
})
it('should show a warning toast if trying to create an empty thread', async () => {
@ -212,13 +196,12 @@ describe('useCreateNewThread', () => {
const { result } = renderHook(() => useCreateNewThread())
const mockThread = { id: 'thread1', title: 'Test Thread' }
const mockThread = { id: 'thread1', title: 'Test Thread', assistants: [{}] }
await act(async () => {
await result.current.updateThreadMetadata(mockThread as any)
})
expect(mockUpdateThread).toHaveBeenCalledWith(mockThread)
expect(extensionManager.get).toHaveBeenCalled()
})
})

View File

@ -1,7 +1,6 @@
import { useCallback } from 'react'
import {
Assistant,
ConversationalExtension,
ExtensionTypeEnum,
Thread,
@ -9,8 +8,11 @@ import {
ThreadState,
AssistantTool,
Model,
Assistant,
} from '@janhq/core'
import { atom, useAtomValue, useSetAtom } from 'jotai'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { useDebouncedCallback } from 'use-debounce'
import { copyOverInstructionEnabledAtom } from '@/containers/CopyInstruction'
import { fileUploadAtom } from '@/containers/Providers/Jotai'
@ -18,7 +20,6 @@ import { fileUploadAtom } from '@/containers/Providers/Jotai'
import { toaster } from '@/containers/Toast'
import { isLocalEngine } from '@/utils/modelEngine'
import { generateThreadId } from '@/utils/thread'
import { useActiveModel } from './useActiveModel'
import useRecommendedModel from './useRecommendedModel'
@ -28,6 +29,7 @@ import useSetActiveThread from './useSetActiveThread'
import { extensionManager } from '@/extension'
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import {
threadsAtom,
@ -35,7 +37,6 @@ import {
updateThreadAtom,
setThreadModelParamsAtom,
isGeneratingResponseAtom,
activeThreadAtom,
} from '@/helpers/atoms/Thread.atom'
const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
@ -65,7 +66,7 @@ export const useCreateNewThread = () => {
const copyOverInstructionEnabled = useAtomValue(
copyOverInstructionEnabledAtom
)
const activeThread = useAtomValue(activeThreadAtom)
const [activeAssistant, setActiveAssistant] = useAtom(activeAssistantAtom)
const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom)
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
@ -76,7 +77,7 @@ export const useCreateNewThread = () => {
const { stopInference } = useActiveModel()
const requestCreateNewThread = async (
assistant: Assistant,
assistant: (ThreadAssistantInfo & { id: string; name: string }) | Assistant,
model?: Model | undefined
) => {
// Stop generating if any
@ -127,7 +128,7 @@ export const useCreateNewThread = () => {
const createdAt = Date.now()
let instructions: string | undefined = assistant.instructions
if (copyOverInstructionEnabled) {
instructions = activeThread?.assistants[0]?.instructions ?? undefined
instructions = activeAssistant?.instructions ?? undefined
}
const assistantInfo: ThreadAssistantInfo = {
assistant_id: assistant.id,
@ -142,45 +143,94 @@ export const useCreateNewThread = () => {
instructions,
}
const threadId = generateThreadId(assistant.id)
const thread: Thread = {
id: threadId,
const thread: Partial<Thread> = {
object: 'thread',
title: 'New Thread',
assistants: [assistantInfo],
created: createdAt,
updated: createdAt,
metadata: {
title: 'New Thread',
},
}
// add the new thread on top of the thread list to the state
//TODO: Why do we have thread list then thread states? Should combine them
createNewThread(thread)
try {
const createdThread = await persistNewThread(thread, assistantInfo)
if (!createdThread) throw 'Thread created failed.'
createNewThread(createdThread)
setSelectedModel(defaultModel)
setThreadModelParams(thread.id, {
setThreadModelParams(createdThread.id, {
...defaultModel?.settings,
...defaultModel?.parameters,
...overriddenSettings,
})
// Delete the file upload state
setFileUpload([])
// Update thread metadata
await updateThreadMetadata(thread)
setActiveThread(thread)
setFileUpload(undefined)
setActiveThread(createdThread)
} catch (ex) {
return toaster({
title: 'Thread created failed.',
description: `To avoid piling up empty threads, please reuse previous one before creating new.`,
type: 'error',
})
}
}
const updateThreadExtension = (thread: Thread) => {
return extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.modifyThread(thread)
}
const updateAssistantExtension = (
threadId: string,
assistant: ThreadAssistantInfo
) => {
return extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.modifyThreadAssistant(threadId, assistant)
}
const updateThreadCallback = useDebouncedCallback(updateThreadExtension, 300)
const updateAssistantCallback = useDebouncedCallback(
updateAssistantExtension,
300
)
const updateThreadMetadata = useCallback(
async (thread: Thread) => {
updateThread(thread)
setActiveAssistant(thread.assistants[0])
updateThreadCallback(thread)
updateAssistantCallback(thread.id, thread.assistants[0])
},
[
updateThread,
setActiveAssistant,
updateThreadCallback,
updateAssistantCallback,
]
)
const persistNewThread = async (
thread: Partial<Thread>,
assistantInfo: ThreadAssistantInfo
): Promise<Thread | undefined> => {
return await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.createThread(thread)
.then(async (thread) => {
await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.saveThread(thread)
},
[updateThread]
)
?.createThreadAssistant(thread.id, assistantInfo)
return thread
})
}
return {
requestCreateNewThread,

View File

@ -2,8 +2,7 @@ import { renderHook, act } from '@testing-library/react'
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
import useDeleteThread from './useDeleteThread'
import { extensionManager } from '@/extension/ExtensionManager'
import { toaster } from '@/containers/Toast'
import { useCreateNewThread } from './useCreateNewThread'
// Mock the necessary dependencies
// Mock dependencies
jest.mock('jotai', () => ({
@ -12,6 +11,7 @@ jest.mock('jotai', () => ({
useAtom: jest.fn(),
atom: jest.fn(),
}))
jest.mock('./useCreateNewThread')
jest.mock('@/extension/ExtensionManager')
jest.mock('@/containers/Toast')
@ -27,8 +27,13 @@ describe('useDeleteThread', () => {
]
const mockSetThreads = jest.fn()
;(useAtom as jest.Mock).mockReturnValue([mockThreads, mockSetThreads])
;(useSetAtom as jest.Mock).mockReturnValue(() => {})
;(useCreateNewThread as jest.Mock).mockReturnValue({})
const mockDeleteThread = jest.fn().mockImplementation(() => ({
catch: () => jest.fn,
}))
const mockDeleteThread = jest.fn()
extensionManager.get = jest.fn().mockReturnValue({
deleteThread: mockDeleteThread,
})
@ -50,12 +55,17 @@ describe('useDeleteThread', () => {
const mockCleanMessages = jest.fn()
;(useSetAtom as jest.Mock).mockReturnValue(() => mockCleanMessages)
;(useAtomValue as jest.Mock).mockReturnValue(['thread 1'])
const mockCreateNewThread = jest.fn()
;(useCreateNewThread as jest.Mock).mockReturnValue({
requestCreateNewThread: mockCreateNewThread,
})
const mockWriteMessages = jest.fn()
const mockSaveThread = jest.fn()
const mockDeleteThread = jest.fn().mockResolvedValue({})
extensionManager.get = jest.fn().mockReturnValue({
writeMessages: mockWriteMessages,
saveThread: mockSaveThread,
getThreadAssistant: jest.fn().mockResolvedValue({}),
deleteThread: mockDeleteThread,
})
const { result } = renderHook(() => useDeleteThread())
@ -64,20 +74,18 @@ describe('useDeleteThread', () => {
await result.current.cleanThread('thread1')
})
expect(mockWriteMessages).toHaveBeenCalled()
expect(mockSaveThread).toHaveBeenCalledWith(
expect.objectContaining({
id: 'thread1',
title: 'New Thread',
metadata: expect.objectContaining({ lastMessage: undefined }),
})
)
expect(mockDeleteThread).toHaveBeenCalled()
expect(mockCreateNewThread).toHaveBeenCalled()
})
it('should handle errors when deleting a thread', async () => {
const mockThreads = [{ id: 'thread1', title: 'Thread 1' }]
const mockSetThreads = jest.fn()
;(useAtom as jest.Mock).mockReturnValue([mockThreads, mockSetThreads])
const mockCreateNewThread = jest.fn()
;(useCreateNewThread as jest.Mock).mockReturnValue({
requestCreateNewThread: mockCreateNewThread,
})
const mockDeleteThread = jest
.fn()
@ -98,8 +106,6 @@ describe('useDeleteThread', () => {
expect(mockDeleteThread).toHaveBeenCalledWith('thread1')
expect(consoleErrorSpy).toHaveBeenCalledWith(expect.any(Error))
expect(mockSetThreads).not.toHaveBeenCalled()
expect(toaster).not.toHaveBeenCalled()
consoleErrorSpy.mockRestore()
})

View File

@ -1,13 +1,6 @@
import { useCallback } from 'react'
import {
ChatCompletionRole,
ExtensionTypeEnum,
ConversationalExtension,
fs,
joinPath,
Thread,
} from '@janhq/core'
import { ExtensionTypeEnum, ConversationalExtension } from '@janhq/core'
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
@ -15,89 +8,63 @@ import { currentPromptAtom } from '@/containers/Providers/Jotai'
import { toaster } from '@/containers/Toast'
import { useCreateNewThread } from './useCreateNewThread'
import { extensionManager } from '@/extension/ExtensionManager'
import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom'
import {
chatMessages,
cleanChatMessageAtom as cleanChatMessagesAtom,
deleteChatMessageAtom as deleteChatMessagesAtom,
} from '@/helpers/atoms/ChatMessage.atom'
import { assistantsAtom } from '@/helpers/atoms/Assistant.atom'
import { deleteChatMessageAtom as deleteChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import {
threadsAtom,
setActiveThreadIdAtom,
deleteThreadStateAtom,
updateThreadStateLastMessageAtom,
updateThreadAtom,
} from '@/helpers/atoms/Thread.atom'
export default function useDeleteThread() {
const [threads, setThreads] = useAtom(threadsAtom)
const messages = useAtomValue(chatMessages)
const janDataFolderPath = useAtomValue(janDataFolderPathAtom)
const { requestCreateNewThread } = useCreateNewThread()
const assistants = useAtomValue(assistantsAtom)
const models = useAtomValue(downloadedModelsAtom)
const setCurrentPrompt = useSetAtom(currentPromptAtom)
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
const deleteMessages = useSetAtom(deleteChatMessagesAtom)
const cleanMessages = useSetAtom(cleanChatMessagesAtom)
const deleteThreadState = useSetAtom(deleteThreadStateAtom)
const updateThreadLastMessage = useSetAtom(updateThreadStateLastMessageAtom)
const updateThread = useSetAtom(updateThreadAtom)
const cleanThread = useCallback(
async (threadId: string) => {
cleanMessages(threadId)
const thread = threads.find((c) => c.id === threadId)
if (!thread) return
const updatedMessages = (messages[threadId] ?? []).filter(
(msg) => msg.role === ChatCompletionRole.System
)
// remove files
try {
const threadFolderPath = await joinPath([
janDataFolderPath,
'threads',
threadId,
])
const threadFilesPath = await joinPath([threadFolderPath, 'files'])
const threadMemoryPath = await joinPath([threadFolderPath, 'memory'])
await fs.rm(threadFilesPath)
await fs.rm(threadMemoryPath)
} catch (err) {
console.warn('Error deleting thread files', err)
}
await extensionManager
const assistantInfo = await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.writeMessages(threadId, updatedMessages)
?.getThreadAssistant(thread.id)
thread.metadata = {
...thread.metadata,
}
if (!assistantInfo) return
const model = models.find((c) => c.id === assistantInfo?.model?.id)
const updatedThread: Thread = {
...thread,
title: 'New Thread',
metadata: { ...thread.metadata, lastMessage: undefined },
}
await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.saveThread(updatedThread)
updateThreadLastMessage(threadId, undefined)
updateThread(updatedThread)
requestCreateNewThread(
{
...assistantInfo,
id: assistants[0].id,
name: assistants[0].name,
},
[
cleanMessages,
threads,
messages,
updateThreadLastMessage,
updateThread,
janDataFolderPath,
]
model
? {
...model,
parameters: assistantInfo?.model?.parameters ?? {},
settings: assistantInfo?.model?.settings ?? {},
}
: undefined
)
// Delete this thread
await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.deleteThread(threadId)
.catch(console.error)
},
[assistants, models, requestCreateNewThread, threads]
)
const deleteThread = async (threadId: string) => {
@ -105,10 +72,10 @@ export default function useDeleteThread() {
alert('No active thread')
return
}
try {
await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.deleteThread(threadId)
.catch(console.error)
const availableThreads = threads.filter((c) => c.id !== threadId)
setThreads(availableThreads)
@ -127,9 +94,6 @@ export default function useDeleteThread() {
} else {
setActiveThreadId(undefined)
}
} catch (err) {
console.error(err)
}
}
return {

View File

@ -1,3 +1,6 @@
/**
* @jest-environment jsdom
*/
// useDropModelBinaries.test.ts
import { renderHook, act } from '@testing-library/react'
@ -18,6 +21,7 @@ jest.mock('jotai', () => ({
jest.mock('uuid')
jest.mock('@/utils/file')
jest.mock('@/containers/Toast')
jest.mock("@uppy/core")
describe('useDropModelBinaries', () => {
const mockSetImportingModels = jest.fn()

View File

@ -2,6 +2,7 @@ import { openFileExplorer, joinPath, baseName } from '@janhq/core'
import { useAtomValue } from 'jotai'
import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
@ -9,13 +10,14 @@ export const usePath = () => {
const janDataFolderPath = useAtomValue(janDataFolderPathAtom)
const activeThread = useAtomValue(activeThreadAtom)
const selectedModel = useAtomValue(selectedModelAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const onRevealInFinder = async (type: string) => {
// TODO: this logic should be refactored.
if (type !== 'Model' && !activeThread) return
let filePath = undefined
const assistantId = activeThread?.assistants[0]?.assistant_id
const assistantId = activeAssistant?.assistant_id
switch (type) {
case 'Engine':
case 'Thread':

View File

@ -6,6 +6,7 @@ import { atom, useAtomValue } from 'jotai'
import { activeModelAtom } from './useActiveModel'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
@ -28,6 +29,7 @@ export default function useRecommendedModel() {
const [recommendedModel, setRecommendedModel] = useState<Model | undefined>()
const activeThread = useAtomValue(activeThreadAtom)
const downloadedModels = useAtomValue(downloadedModelsAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const getAndSortDownloadedModels = useCallback(async (): Promise<Model[]> => {
const models = downloadedModels.sort((a, b) =>
@ -45,8 +47,8 @@ export default function useRecommendedModel() {
> => {
const models = await getAndSortDownloadedModels()
if (!activeThread) return
const modelId = activeThread.assistants[0]?.model.id
if (!activeThread || !activeAssistant) return
const modelId = activeAssistant.model.id
const model = models.find((model) => model.id === modelId)
if (model) {

View File

@ -10,6 +10,7 @@ import {
ConversationalExtension,
EngineManager,
ToolManager,
ThreadAssistantInfo,
} from '@janhq/core'
import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
@ -28,6 +29,7 @@ import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
import { useActiveModel } from './useActiveModel'
import { extensionManager } from '@/extension/ExtensionManager'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import {
addNewMessageAtom,
deleteMessageAtom,
@ -48,6 +50,7 @@ export const reloadModelAtom = atom(false)
export default function useSendChatMessage() {
const activeThread = useAtomValue(activeThreadAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const addNewMessage = useSetAtom(addNewMessageAtom)
const updateThread = useSetAtom(updateThreadAtom)
const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
@ -68,6 +71,7 @@ export default function useSendChatMessage() {
const [fileUpload, setFileUpload] = useAtom(fileUploadAtom)
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
const activeThreadRef = useRef<Thread | undefined>()
const activeAssistantRef = useRef<ThreadAssistantInfo | undefined>()
const setTokenSpeed = useSetAtom(tokenSpeedAtom)
const selectedModelRef = useRef<Model | undefined>()
@ -84,36 +88,37 @@ export default function useSendChatMessage() {
selectedModelRef.current = selectedModel
}, [selectedModel])
const resendChatMessage = async (currentMessage: ThreadMessage) => {
useEffect(() => {
activeAssistantRef.current = activeAssistant
}, [activeAssistant])
const resendChatMessage = async () => {
// Delete last response before regenerating
const newConvoData = currentMessages
let toSendMessage = currentMessage
const newConvoData = Array.from(currentMessages)
let toSendMessage = newConvoData.pop()
do {
deleteMessage(currentMessage.id)
const msg = newConvoData.pop()
if (!msg) break
toSendMessage = msg
deleteMessage(toSendMessage.id ?? '')
} while (toSendMessage.role !== ChatCompletionRole.User)
if (activeThreadRef.current) {
while (toSendMessage && toSendMessage?.role !== ChatCompletionRole.User) {
await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.writeMessages(activeThreadRef.current.id, newConvoData)
?.deleteMessage(toSendMessage.thread_id, toSendMessage.id)
.catch(console.error)
deleteMessage(toSendMessage.id ?? '')
toSendMessage = newConvoData.pop()
}
sendChatMessage(toSendMessage.content[0]?.text.value)
if (toSendMessage?.content[0]?.text?.value)
sendChatMessage(toSendMessage.content[0].text.value, true)
}
const sendChatMessage = async (
message: string,
isResend: boolean = false,
messages?: ThreadMessage[]
) => {
if (!message || message.trim().length === 0) return
if (!activeThreadRef.current) {
console.error('No active thread')
if (!activeThreadRef.current || !activeAssistantRef.current) {
console.error('No active thread or assistant')
return
}
@ -129,21 +134,19 @@ export default function useSendChatMessage() {
setCurrentPrompt('')
setEditPrompt('')
let base64Blob = fileUpload[0]
? await getBase64(fileUpload[0].file)
: undefined
let base64Blob = fileUpload ? await getBase64(fileUpload.file) : undefined
if (base64Blob && fileUpload[0]?.type === 'image') {
if (base64Blob && fileUpload?.type === 'image') {
// Compress image
base64Blob = await compressImage(base64Blob, 512)
}
const modelRequest =
selectedModelRef?.current ?? activeThreadRef.current.assistants[0].model
selectedModelRef?.current ?? activeAssistantRef.current?.model
// Fallback support for previous broken threads
if (activeThreadRef.current?.assistants[0]?.model?.id === '*') {
activeThreadRef.current.assistants[0].model = {
if (activeAssistantRef.current?.model?.id === '*') {
activeAssistantRef.current.model = {
id: modelRequest.id,
settings: modelRequest.settings,
parameters: modelRequest.parameters,
@ -163,9 +166,10 @@ export default function useSendChatMessage() {
},
activeThreadRef.current,
messages ?? currentMessages
).addSystemMessage(activeThreadRef.current.assistants[0].instructions)
).addSystemMessage(activeAssistantRef.current?.instructions)
requestBuilder.pushMessage(prompt, base64Blob, fileUpload[0]?.type)
if (!isResend) {
requestBuilder.pushMessage(prompt, base64Blob, fileUpload)
// Build Thread Message to persist
const threadMessageBuilder = new ThreadMessageBuilder(
@ -174,9 +178,6 @@ export default function useSendChatMessage() {
const newMessage = threadMessageBuilder.build()
// Push to states
addNewMessage(newMessage)
// Update thread state
const updatedThread: Thread = {
...activeThreadRef.current,
@ -189,20 +190,25 @@ export default function useSendChatMessage() {
updateThread(updatedThread)
// Add message
await extensionManager
const createdMessage = await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.addNewMessage(newMessage)
?.createMessage(newMessage)
if (!createdMessage) return
// Push to states
addNewMessage(createdMessage)
}
// Start Model if not started
const modelId =
selectedModelRef.current?.id ??
activeThreadRef.current.assistants[0].model.id
selectedModelRef.current?.id ?? activeAssistantRef.current?.model.id
if (base64Blob) {
setFileUpload([])
setFileUpload(undefined)
}
if (modelRef.current?.id !== modelId) {
if (modelRef.current?.id !== modelId && modelId) {
const error = await startModel(modelId).catch((error: Error) => error)
if (error) {
updateThreadWaiting(activeThreadRef.current.id, false)
@ -214,9 +220,7 @@ export default function useSendChatMessage() {
// Process message request with Assistants tools
const request = await ToolManager.instance().process(
requestBuilder.build(),
activeThreadRef.current.assistants?.flatMap(
(assistant) => assistant.tools ?? []
) ?? []
activeAssistantRef?.current.tools ?? []
)
// Request for inference

View File

@ -1,12 +1,10 @@
import { ExtensionTypeEnum, Thread, ConversationalExtension } from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai'
import { useSetAtom } from 'jotai'
import { extensionManager } from '@/extension'
import {
readyThreadsMessagesAtom,
setConvoMessagesAtom,
} from '@/helpers/atoms/ChatMessage.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { setConvoMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
import {
setActiveThreadIdAtom,
setThreadModelParamsAtom,
@ -17,21 +15,27 @@ export default function useSetActiveThread() {
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
const setThreadMessage = useSetAtom(setConvoMessagesAtom)
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
const readyMessageThreads = useAtomValue(readyThreadsMessagesAtom)
const setActiveAssistant = useSetAtom(activeAssistantAtom)
const setActiveThread = async (thread: Thread) => {
// Load local messages only if there are no messages in the state
if (!readyMessageThreads[thread?.id]) {
const messages = await getLocalThreadMessage(thread?.id)
setThreadMessage(thread?.id, messages)
}
if (!thread?.id) return
setActiveThreadId(thread?.id)
try {
const assistantInfo = await getThreadAssistant(thread.id)
setActiveAssistant(assistantInfo)
// Load local messages only if there are no messages in the state
const messages = await getLocalThreadMessage(thread.id).catch(() => [])
const modelParams: ModelParams = {
...thread?.assistants[0]?.model?.parameters,
...thread?.assistants[0]?.model?.settings,
...assistantInfo?.model?.parameters,
...assistantInfo?.model?.settings,
}
setThreadModelParams(thread?.id, modelParams)
setThreadMessage(thread.id, messages)
} catch (e) {
console.error(e)
}
}
return { setActiveThread }
@ -40,4 +44,9 @@ export default function useSetActiveThread() {
const getLocalThreadMessage = async (threadId: string) =>
extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.getAllMessages(threadId) ?? []
?.listMessages(threadId) ?? []
const getThreadAssistant = async (threadId: string) =>
extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.getThreadAssistant(threadId)

View File

@ -78,7 +78,7 @@ describe('useThreads', () => {
// Mock extensionManager
const mockGetThreads = jest.fn().mockResolvedValue(mockThreads)
;(extensionManager.get as jest.Mock).mockReturnValue({
getThreads: mockGetThreads,
listThreads: mockGetThreads,
})
const { result } = renderHook(() => useThreads())
@ -119,7 +119,7 @@ describe('useThreads', () => {
it('should handle empty threads', async () => {
// Mock empty threads
;(extensionManager.get as jest.Mock).mockReturnValue({
getThreads: jest.fn().mockResolvedValue([]),
listThreads: jest.fn().mockResolvedValue([]),
})
const mockSetThreadStates = jest.fn()

View File

@ -68,6 +68,6 @@ const useThreads = () => {
const getLocalThreads = async (): Promise<Thread[]> =>
(await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.getThreads()) ?? []
?.listThreads()) ?? []
export default useThreads

View File

@ -1,7 +1,12 @@
import { renderHook, act } from '@testing-library/react'
import { useAtom } from 'jotai'
// Mock dependencies
jest.mock('ulidx')
jest.mock('@/extension')
jest.mock('jotai', () => ({
...jest.requireActual('jotai'),
useAtom: jest.fn(),
}))
import useUpdateModelParameters from './useUpdateModelParameters'
import { extensionManager } from '@/extension'
@ -13,7 +18,8 @@ let model: any = {
}
let extension: any = {
saveThread: jest.fn(),
modifyThread: jest.fn(),
modifyThreadAssistant: jest.fn(),
}
const mockThread: any = {
@ -35,6 +41,7 @@ const mockThread: any = {
describe('useUpdateModelParameters', () => {
beforeAll(() => {
jest.clearAllMocks()
jest.useFakeTimers()
jest.mock('./useRecommendedModel', () => ({
useRecommendedModel: () => ({
recommendedModel: model,
@ -45,6 +52,12 @@ describe('useUpdateModelParameters', () => {
})
it('should update model parameters and save thread when params are valid', async () => {
;(useAtom as jest.Mock).mockReturnValue([
{
id: 'assistant-1',
},
jest.fn(),
])
const mockValidParameters: any = {
params: {
// Inference
@ -76,7 +89,8 @@ describe('useUpdateModelParameters', () => {
// Spy functions
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
jest.spyOn(extension, 'saveThread').mockReturnValue({})
jest.spyOn(extension, 'modifyThread').mockReturnValue({})
jest.spyOn(extension, 'modifyThreadAssistant').mockReturnValue({})
const { result } = renderHook(() => useUpdateModelParameters())
@ -84,10 +98,11 @@ describe('useUpdateModelParameters', () => {
await result.current.updateModelParameter(mockThread, mockValidParameters)
})
jest.runAllTimers()
// Check if the model parameters are valid before persisting
expect(extension.saveThread).toHaveBeenCalledWith({
assistants: [
{
expect(extension.modifyThreadAssistant).toHaveBeenCalledWith('thread-1', {
id: 'assistant-1',
model: {
parameters: {
stop: ['<eos>', '<eos2>'],
@ -110,18 +125,19 @@ describe('useUpdateModelParameters', () => {
llama_model_path: 'path',
mmproj: 'mmproj',
},
id: 'model-1',
engine: 'nitro',
},
},
],
created: 0,
id: 'thread-1',
object: 'thread',
title: 'New Thread',
updated: 0,
})
})
it('should not update invalid model parameters', async () => {
;(useAtom as jest.Mock).mockReturnValue([
{
id: 'assistant-1',
},
jest.fn(),
])
const mockInvalidParameters: any = {
params: {
// Inference
@ -153,7 +169,8 @@ describe('useUpdateModelParameters', () => {
// Spy functions
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
jest.spyOn(extension, 'saveThread').mockReturnValue({})
jest.spyOn(extension, 'modifyThread').mockReturnValue({})
jest.spyOn(extension, 'modifyThreadAssistant').mockReturnValue({})
const { result } = renderHook(() => useUpdateModelParameters())
@ -164,14 +181,17 @@ describe('useUpdateModelParameters', () => {
)
})
jest.runAllTimers()
// Check if the model parameters are valid before persisting
expect(extension.saveThread).toHaveBeenCalledWith({
assistants: [
{
expect(extension.modifyThreadAssistant).toHaveBeenCalledWith('thread-1', {
id: 'assistant-1',
model: {
engine: 'nitro',
id: 'model-1',
parameters: {
max_tokens: 1000,
token_limit: 1000,
max_tokens: 1000,
},
settings: {
cpu_threads: 4,
@ -183,17 +203,16 @@ describe('useUpdateModelParameters', () => {
ngl: 12,
},
},
},
],
created: 0,
id: 'thread-1',
object: 'thread',
title: 'New Thread',
updated: 0,
})
})
it('should update valid model parameters only', async () => {
;(useAtom as jest.Mock).mockReturnValue([
{
id: 'assistant-1',
},
jest.fn(),
])
const mockInvalidParameters: any = {
params: {
// Inference
@ -225,8 +244,8 @@ describe('useUpdateModelParameters', () => {
// Spy functions
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
jest.spyOn(extension, 'saveThread').mockReturnValue({})
jest.spyOn(extension, 'modifyThread').mockReturnValue({})
jest.spyOn(extension, 'modifyThreadAssistant').mockReturnValue({})
const { result } = renderHook(() => useUpdateModelParameters())
await act(async () => {
@ -235,12 +254,14 @@ describe('useUpdateModelParameters', () => {
mockInvalidParameters
)
})
jest.runAllTimers()
// Check if the model parameters are valid before persisting
expect(extension.saveThread).toHaveBeenCalledWith({
assistants: [
{
expect(extension.modifyThreadAssistant).toHaveBeenCalledWith('thread-1', {
id: 'assistant-1',
model: {
engine: 'nitro',
id: 'model-1',
parameters: {
stop: ['<eos>'],
top_k: 0.7,
@ -260,55 +281,6 @@ describe('useUpdateModelParameters', () => {
mmproj: 'mmproj',
},
},
},
],
created: 0,
id: 'thread-1',
object: 'thread',
title: 'New Thread',
updated: 0,
})
})
it('should handle missing modelId and engine gracefully', async () => {
const mockParametersWithoutModelIdAndEngine: any = {
params: {
stop: ['<eos>', '<eos2>'],
temperature: 0.5,
},
}
// Spy functions
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
jest.spyOn(extension, 'saveThread').mockReturnValue({})
const { result } = renderHook(() => useUpdateModelParameters())
await act(async () => {
await result.current.updateModelParameter(
mockThread,
mockParametersWithoutModelIdAndEngine
)
})
// Check if the model parameters are valid before persisting
expect(extension.saveThread).toHaveBeenCalledWith({
assistants: [
{
model: {
parameters: {
stop: ['<eos>', '<eos2>'],
temperature: 0.5,
},
settings: {},
},
},
],
created: 0,
id: 'thread-1',
object: 'thread',
title: 'New Thread',
updated: 0,
})
})
})

View File

@ -12,7 +12,10 @@ import {
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
import { useDebouncedCallback } from 'use-debounce'
import { extensionManager } from '@/extension'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import {
getActiveThreadModelParamsAtom,
@ -29,11 +32,28 @@ export type UpdateModelParameter = {
export default function useUpdateModelParameters() {
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
const [activeAssistant, setActiveAssistant] = useAtom(activeAssistantAtom)
const [selectedModel] = useAtom(selectedModelAtom)
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
const updateAssistantExtension = (
threadId: string,
assistant: ThreadAssistantInfo
) => {
return extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.modifyThreadAssistant(threadId, assistant)
}
const updateAssistantCallback = useDebouncedCallback(
updateAssistantExtension,
300
)
const updateModelParameter = useCallback(
async (thread: Thread, settings: UpdateModelParameter) => {
if (!activeAssistant) return
const toUpdateSettings = processStopWords(settings.params ?? {})
const updatedModelParams = settings.modelId
? toUpdateSettings
@ -48,30 +68,34 @@ export default function useUpdateModelParameters() {
setThreadModelParams(thread.id, updatedModelParams)
const runtimeParams = extractInferenceParams(updatedModelParams)
const settingParams = extractModelLoadParams(updatedModelParams)
const assistants = thread.assistants.map(
(assistant: ThreadAssistantInfo) => {
assistant.model.parameters = runtimeParams
assistant.model.settings = settingParams
if (selectedModel) {
assistant.model.id = settings.modelId ?? selectedModel?.id
assistant.model.engine = settings.engine ?? selectedModel?.engine
}
return assistant
}
)
// update thread
const updatedThread: Thread = {
...thread,
assistants,
}
await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.saveThread(updatedThread)
const assistantInfo = {
...activeAssistant,
model: {
...activeAssistant?.model,
parameters: runtimeParams,
settings: settingParams,
id: settings.modelId ?? selectedModel?.id ?? activeAssistant.model.id,
engine:
settings.engine ??
selectedModel?.engine ??
activeAssistant.model.engine,
},
[activeModelParams, selectedModel, setThreadModelParams]
}
setActiveAssistant(assistantInfo)
updateAssistantCallback(thread.id, assistantInfo)
},
[
activeAssistant,
selectedModel?.parameters,
selectedModel?.settings,
selectedModel?.id,
selectedModel?.engine,
activeModelParams,
setThreadModelParams,
setActiveAssistant,
updateAssistantCallback,
]
)
const processStopWords = (params: ModelParams): ModelParams => {

View File

@ -37,5 +37,5 @@ const config = {
// module.exports = createJestConfig(config)
module.exports = async () => ({
...(await createJestConfig(config)()),
transformIgnorePatterns: ['/node_modules/(?!(layerr)/)'],
transformIgnorePatterns: ['/node_modules/(?!(layerr|nanoid|@uppy|preact)/)'],
})

View File

@ -35,7 +35,7 @@ const nextConfig = {
POSTHOG_HOST: JSON.stringify(process.env.POSTHOG_HOST),
ANALYTICS_HOST: JSON.stringify(process.env.ANALYTICS_HOST),
API_BASE_URL: JSON.stringify(
process.env.API_BASE_URL ?? 'http://localhost:1337'
process.env.API_BASE_URL ?? 'http://127.0.0.1:39291'
),
isMac: process.platform === 'darwin',
isWindows: process.platform === 'win32',

View File

@ -17,6 +17,9 @@
"@janhq/core": "link:./core",
"@janhq/joi": "link:./joi",
"@tanstack/react-virtual": "^3.10.9",
"@uppy/core": "^4.3.0",
"@uppy/react": "^4.0.4",
"@uppy/xhr-upload": "^4.2.3",
"autoprefixer": "10.4.16",
"class-variance-authority": "^0.7.0",
"framer-motion": "^10.16.4",

View File

@ -7,6 +7,8 @@ import { useAtomValue, useSetAtom } from 'jotai'
import { useActiveModel } from '@/hooks/useActiveModel'
import { useCreateNewThread } from '@/hooks/useCreateNewThread'
import AssistantSetting from './index'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
jest.mock('jotai', () => {
const originalModule = jest.requireActual('jotai')
@ -68,6 +70,7 @@ describe('AssistantSetting Component', () => {
beforeEach(() => {
jest.clearAllMocks()
jest.useFakeTimers()
})
test('renders AssistantSetting component with proper data', async () => {
@ -75,7 +78,14 @@ describe('AssistantSetting Component', () => {
;(useSetAtom as jest.Mock).mockImplementationOnce(
() => setEngineParamsUpdate
)
;(useAtomValue as jest.Mock).mockImplementationOnce(() => mockActiveThread)
;(useAtomValue as jest.Mock).mockImplementation((atom) => {
switch (atom) {
case activeThreadAtom:
return mockActiveThread
case activeAssistantAtom:
return {}
}
})
const updateThreadMetadata = jest.fn()
;(useActiveModel as jest.Mock).mockReturnValueOnce({ stopModel: jest.fn() })
;(useCreateNewThread as jest.Mock).mockReturnValueOnce({
@ -98,7 +108,14 @@ describe('AssistantSetting Component', () => {
const setEngineParamsUpdate = jest.fn()
const updateThreadMetadata = jest.fn()
const stopModel = jest.fn()
;(useAtomValue as jest.Mock).mockImplementationOnce(() => mockActiveThread)
;(useAtomValue as jest.Mock).mockImplementation((atom) => {
switch (atom) {
case activeThreadAtom:
return mockActiveThread
case activeAssistantAtom:
return {}
}
})
;(useSetAtom as jest.Mock).mockImplementation(() => setEngineParamsUpdate)
;(useActiveModel as jest.Mock).mockReturnValueOnce({ stopModel })
;(useCreateNewThread as jest.Mock).mockReturnValueOnce({

View File

@ -8,6 +8,7 @@ import { useCreateNewThread } from '@/hooks/useCreateNewThread'
import SettingComponentBuilder from '../../../../containers/ModelSetting/SettingComponent'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import {
activeThreadAtom,
engineParamsUpdateAtom,
@ -19,13 +20,14 @@ type Props = {
const AssistantSetting: React.FC<Props> = ({ componentData }) => {
const activeThread = useAtomValue(activeThreadAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const { updateThreadMetadata } = useCreateNewThread()
const { stopModel } = useActiveModel()
const setEngineParamsUpdate = useSetAtom(engineParamsUpdateAtom)
const onValueChanged = useCallback(
(key: string, value: string | number | boolean | string[]) => {
if (!activeThread) return
if (!activeThread || !activeAssistant) return
const shouldReloadModel =
componentData.find((x) => x.key === key)?.requireModelReload ?? false
if (shouldReloadModel) {
@ -34,40 +36,40 @@ const AssistantSetting: React.FC<Props> = ({ componentData }) => {
}
if (
activeThread.assistants[0].tools &&
activeAssistant?.tools &&
(key === 'chunk_overlap' || key === 'chunk_size')
) {
if (
activeThread.assistants[0].tools[0]?.settings?.chunk_size <
activeThread.assistants[0].tools[0]?.settings?.chunk_overlap
activeAssistant.tools[0]?.settings?.chunk_size <
activeAssistant.tools[0]?.settings?.chunk_overlap
) {
activeThread.assistants[0].tools[0].settings.chunk_overlap =
activeThread.assistants[0].tools[0].settings.chunk_size
activeAssistant.tools[0].settings.chunk_overlap =
activeAssistant.tools[0].settings.chunk_size
}
if (
key === 'chunk_size' &&
value < activeThread.assistants[0].tools[0].settings?.chunk_overlap
value < activeAssistant.tools[0].settings?.chunk_overlap
) {
activeThread.assistants[0].tools[0].settings.chunk_overlap = value
activeAssistant.tools[0].settings.chunk_overlap = value
} else if (
key === 'chunk_overlap' &&
value > activeThread.assistants[0].tools[0].settings?.chunk_size
value > activeAssistant.tools[0].settings?.chunk_size
) {
activeThread.assistants[0].tools[0].settings.chunk_size = value
activeAssistant.tools[0].settings.chunk_size = value
}
}
updateThreadMetadata({
...activeThread,
assistants: [
{
...activeThread.assistants[0],
...activeAssistant,
tools: [
{
type: 'retrieval',
enabled: true,
settings: {
...(activeThread.assistants[0].tools &&
activeThread.assistants[0].tools[0]?.settings),
...(activeAssistant.tools &&
activeAssistant.tools[0]?.settings),
[key]: value,
},
},
@ -77,6 +79,7 @@ const AssistantSetting: React.FC<Props> = ({ componentData }) => {
})
},
[
activeAssistant,
activeThread,
componentData,
setEngineParamsUpdate,

View File

@ -24,6 +24,7 @@ import { useActiveModel } from '@/hooks/useActiveModel'
import useSendChatMessage from '@/hooks/useSendChatMessage'
import { uploader } from '@/utils/file'
import { isLocalEngine } from '@/utils/modelEngine'
import FileUploadPreview from '../FileUploadPreview'
@ -33,6 +34,7 @@ import RichTextEditor from './RichTextEditor'
import { showRightPanelAtom } from '@/helpers/atoms/App.atom'
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import { spellCheckAtom } from '@/helpers/atoms/Setting.atom'
@ -67,8 +69,10 @@ const ChatInput = () => {
const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom)
const isGeneratingResponse = useAtomValue(isGeneratingResponseAtom)
const threadStates = useAtomValue(threadStatesAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const { stopInference } = useActiveModel()
const upload = uploader()
const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom(
activeTabThreadRightPanelAtom
)
@ -102,18 +106,26 @@ const ChatInput = () => {
const handleFileChange = (event: React.ChangeEvent<HTMLInputElement>) => {
const file = event.target.files?.[0]
if (!file) return
setFileUpload([{ file: file, type: 'pdf' }])
upload.addFile(file)
upload.upload().then((data) => {
setFileUpload({
file: file,
type: 'pdf',
id: data?.successful?.[0]?.response?.body?.id,
name: data?.successful?.[0]?.response?.body?.filename,
})
})
}
const handleImageChange = (event: React.ChangeEvent<HTMLInputElement>) => {
const file = event.target.files?.[0]
if (!file) return
setFileUpload([{ file: file, type: 'image' }])
setFileUpload({ file: file, type: 'image' })
}
const renderPreview = (fileUpload: any) => {
if (fileUpload.length > 0) {
if (fileUpload[0].type === 'image') {
if (fileUpload) {
if (fileUpload.type === 'image') {
return <ImageUploadPreview file={fileUpload[0].file} />
} else {
return <FileUploadPreview />
@ -130,7 +142,7 @@ const ChatInput = () => {
'relative mb-1 max-h-[400px] resize-none rounded-lg border border-[hsla(var(--app-border))] p-3 pr-20',
'focus-within:outline-none focus-visible:outline-0 focus-visible:ring-1 focus-visible:ring-[hsla(var(--primary-bg))] focus-visible:ring-offset-0',
'overflow-y-auto',
fileUpload.length && 'rounded-t-none',
fileUpload && 'rounded-t-none',
experimentalFeature && 'pl-10',
activeSettingInputBox && 'pb-14 pr-16'
)}
@ -152,10 +164,10 @@ const ChatInput = () => {
className="absolute left-3 top-2.5"
onClick={(e) => {
if (
fileUpload.length > 0 ||
(activeThread?.assistants[0].tools &&
!activeThread?.assistants[0].tools[0]?.enabled &&
!activeThread?.assistants[0].model.settings?.vision_model)
!!fileUpload ||
(activeAssistant?.tools &&
!activeAssistant?.tools[0]?.enabled &&
!activeAssistant?.model.settings?.vision_model)
) {
e.stopPropagation()
} else {
@ -171,26 +183,24 @@ const ChatInput = () => {
}
disabled={
isModelSupportRagAndTools &&
activeThread?.assistants[0].tools &&
activeThread?.assistants[0].tools[0]?.enabled
activeAssistant?.tools &&
activeAssistant?.tools[0]?.enabled
}
content={
<>
{fileUpload.length > 0 ||
(activeThread?.assistants[0].tools &&
!activeThread?.assistants[0].tools[0]?.enabled &&
!activeThread?.assistants[0].model.settings
?.vision_model && (
{!!fileUpload ||
(activeAssistant?.tools &&
!activeAssistant?.tools[0]?.enabled &&
!activeAssistant?.model.settings?.vision_model && (
<>
{fileUpload.length !== 0 && (
{!!fileUpload && (
<span>
Currently, we only support 1 attachment at the same
time.
</span>
)}
{activeThread?.assistants[0].tools &&
activeThread?.assistants[0].tools[0]?.enabled ===
false &&
{activeAssistant?.tools &&
activeAssistant?.tools[0]?.enabled === false &&
isModelSupportRagAndTools && (
<span>
Turn on Retrieval in Tools settings to use this
@ -221,14 +231,12 @@ const ChatInput = () => {
<li
className={twMerge(
'text-[hsla(var(--text-secondary)] hover:bg-secondary flex w-full items-center space-x-2 px-4 py-2 hover:bg-[hsla(var(--dropdown-menu-hover-bg))]',
activeThread?.assistants[0].model.settings?.vision_model
activeAssistant?.model.settings?.vision_model
? 'cursor-pointer'
: 'cursor-not-allowed opacity-50'
)}
onClick={() => {
if (
activeThread?.assistants[0].model.settings?.vision_model
) {
if (activeAssistant?.model.settings?.vision_model) {
imageInputRef.current?.click()
setShowAttacmentMenus(false)
}
@ -239,9 +247,7 @@ const ChatInput = () => {
</li>
}
content="This feature only supports multimodal models."
disabled={
activeThread?.assistants[0].model.settings?.vision_model
}
disabled={activeAssistant?.model.settings?.vision_model}
/>
<Tooltip
side="bottom"
@ -261,8 +267,8 @@ const ChatInput = () => {
</li>
}
content={
(!activeThread?.assistants[0].tools ||
!activeThread?.assistants[0].tools[0]?.enabled) && (
(!activeAssistant?.tools ||
!activeAssistant?.tools[0]?.enabled) && (
<span>
Turn on Retrieval in Assistant Settings to use this
feature.

View File

@ -72,7 +72,8 @@ const EditChatInput: React.FC<Props> = ({ message }) => {
}, [editPrompt])
useEffect(() => {
setEditPrompt(message.content[0]?.text?.value)
if (message.content?.[0]?.text?.value)
setEditPrompt(message.content[0].text.value)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
@ -80,19 +81,17 @@ const EditChatInput: React.FC<Props> = ({ message }) => {
setEditMessage('')
const messageIdx = messages.findIndex((msg) => msg.id === message.id)
const newMessages = messages.slice(0, messageIdx)
if (activeThread) {
setMessages(activeThread.id, newMessages)
await extensionManager
const toDeleteMessages = messages.slice(messageIdx)
const threadId = messages[0].thread_id
await Promise.all(
toDeleteMessages.map(async (message) =>
extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.writeMessages(
activeThread.id,
// Remove all of the messages below this
newMessages
?.deleteMessage(message.thread_id, message.id)
)
.then(() => {
sendChatMessage(editPrompt, newMessages)
})
}
)
setMessages(threadId, newMessages)
sendChatMessage(editPrompt, false, newMessages)
}
const onKeyDown = async (e: React.KeyboardEvent<HTMLTextAreaElement>) => {

View File

@ -15,21 +15,22 @@ const FileUploadPreview = () => {
const setCurrentPrompt = useSetAtom(currentPromptAtom)
const onDeleteClick = () => {
setFileUpload([])
setFileUpload(undefined)
setCurrentPrompt('')
}
return (
<div className="flex flex-col rounded-t-lg border border-b-0 border-[hsla(var(--app-border))] p-4">
{!!fileUpload && (
<div className="bg-secondary relative inline-flex w-60 space-x-3 rounded-lg p-4">
<Icon type={fileUpload[0].type} />
<Icon type={fileUpload?.type} />
<div className="w-full">
<h6 className="line-clamp-1 w-3/4 truncate font-medium">
{fileUpload[0].file.name.replaceAll(/[-._]/g, ' ')}
{fileUpload?.file.name.replaceAll(/[-._]/g, ' ')}
</h6>
<p className="text-[hsla(var(--text-secondary)]">
{toGibibytes(fileUpload[0].file.size)}
{toGibibytes(fileUpload?.file.size)}
</p>
</div>
@ -40,6 +41,7 @@ const FileUploadPreview = () => {
<XIcon size={14} className="text-background" />
</div>
</div>
)}
</div>
)
}

View File

@ -29,7 +29,7 @@ const ImageUploadPreview: React.FC<Props> = ({ file }) => {
}
const onDeleteClick = () => {
setFileUpload([])
setFileUpload(undefined)
setCurrentPrompt('')
}

View File

@ -10,15 +10,15 @@ import { MainViewState } from '@/constants/screens'
import { loadModelErrorAtom } from '@/hooks/useActiveModel'
import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
const LoadModelError = () => {
const setModalTroubleShooting = useSetAtom(modalTroubleShootingAtom)
const loadModelError = useAtomValue(loadModelErrorAtom)
const setMainState = useSetAtom(mainViewStateAtom)
const setSelectedSettingScreen = useSetAtom(selectedSettingAtom)
const activeThread = useAtomValue(activeThreadAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const ErrorMessage = () => {
if (
@ -33,9 +33,9 @@ const LoadModelError = () => {
className="cursor-pointer font-medium text-[hsla(var(--app-link))]"
onClick={() => {
setMainState(MainViewState.Settings)
if (activeThread?.assistants[0]?.model.engine) {
if (activeAssistant?.model.engine) {
const engine = EngineManager.instance().get(
activeThread.assistants[0].model.engine
activeAssistant.model.engine
)
engine?.name && setSelectedSettingScreen(engine.name)
}

View File

@ -55,15 +55,11 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
.slice(-1)[0]
if (thread) {
// Should also delete error messages to clear out the error state
// TODO: Should also delete error messages to clear out the error state
await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.writeMessages(
thread.id,
messages.filter(
(msg) => msg.id !== message.id && msg.status !== MessageStatus.Error
)
)
?.deleteMessage(thread.id, message.id)
.catch(console.error)
const updatedThread: Thread = {
...thread,
@ -74,7 +70,7 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
)[
messages.filter((msg) => msg.role === ChatCompletionRole.Assistant)
.length - 1
]?.content[0]?.text.value,
]?.content[0]?.text?.value,
},
}
@ -89,10 +85,6 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
setEditMessage(message.id ?? '')
}
const onRegenerateClick = async () => {
resendChatMessage(message)
}
if (message.status === MessageStatus.Pending) return null
return (
@ -118,11 +110,10 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
{message.id === messages[messages.length - 1]?.id &&
messages[messages.length - 1].status !== MessageStatus.Error &&
messages[messages.length - 1].content[0]?.type !==
ContentType.Pdf && (
!messages[messages.length - 1].attachments?.length && (
<div
className="cursor-pointer rounded-lg border border-[hsla(var(--app-border))] p-2"
onClick={onRegenerateClick}
onClick={resendChatMessage}
>
<Tooltip
trigger={

View File

@ -11,15 +11,7 @@ import { openFileTitle } from '@/utils/titleUtils'
import Icon from '../FileUploadPreview/Icon'
const DocMessage = ({
id,
name,
size,
}: {
id: string
name?: string
size?: number
}) => {
const DocMessage = ({ id, name }: { id: string; name?: string }) => {
const { onViewFile, onViewFileContainer } = usePath()
return (
@ -44,9 +36,9 @@ const DocMessage = ({
<h6 className="line-clamp-1 w-4/5 font-medium">
{name?.replaceAll(/[-._]/g, ' ')}
</h6>
<p className="text-[hsla(var(--text-secondary)]">
{/* <p className="text-[hsla(var(--text-secondary)]">
{toGibibytes(Number(size))}
</p>
</p> */}
</div>
</div>
)

View File

@ -1,6 +1,5 @@
import { memo, useMemo } from 'react'
import { memo } from 'react'
import { ThreadContent } from '@janhq/core'
import { Tooltip } from '@janhq/joi'
import { FolderOpenIcon } from 'lucide-react'
@ -11,21 +10,13 @@ import { openFileTitle } from '@/utils/titleUtils'
import { RelativeImage } from '../TextMessage/RelativeImage'
const ImageMessage = ({ content }: { content: ThreadContent }) => {
const ImageMessage = ({ image }: { image: string }) => {
const { onViewFile, onViewFileContainer } = usePath()
const annotation = useMemo(
() => content?.text?.annotations[0] ?? '',
[content]
)
return (
<div className="group/image relative mb-2 inline-flex cursor-pointer overflow-hidden rounded-xl">
<div className="left-0 top-0 z-20 h-full w-full group-hover/image:inline-block">
<RelativeImage
src={annotation}
onClick={() => onViewFile(annotation)}
/>
<RelativeImage src={image} onClick={() => onViewFile(image)} />
</div>
<Tooltip
trigger={

View File

@ -17,11 +17,11 @@ import DocMessage from './DocMessage'
import ImageMessage from './ImageMessage'
import { MarkdownTextMessage } from './MarkdownTextMessage'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import {
editMessageAtom,
tokenSpeedAtom,
} from '@/helpers/atoms/ChatMessage.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
const MessageContainer: React.FC<
ThreadMessage & { isCurrentMessage: boolean }
@ -29,18 +29,23 @@ const MessageContainer: React.FC<
const isUser = props.role === ChatCompletionRole.User
const isSystem = props.role === ChatCompletionRole.System
const editMessage = useAtomValue(editMessageAtom)
const activeThread = useAtomValue(activeThreadAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const tokenSpeed = useAtomValue(tokenSpeedAtom)
const text = useMemo(
() => props.content[0]?.text?.value ?? '',
() =>
props.content.find((e) => e.type === ContentType.Text)?.text?.value ?? '',
[props.content]
)
const messageType = useMemo(
() => props.content[0]?.type ?? '',
const image = useMemo(
() =>
props.content.find((e) => e.type === ContentType.Image)?.image_url?.url,
[props.content]
)
const attachedFile = useMemo(() => 'attachments' in props, [props])
return (
<div className="group relative mx-auto max-w-[700px] p-4">
<div
@ -75,10 +80,10 @@ const MessageContainer: React.FC<
>
{isUser
? props.role
: (activeThread?.assistants[0].assistant_name ?? props.role)}
: (activeAssistant?.assistant_name ?? props.role)}
</div>
<p className="text-xs font-medium text-gray-400">
{displayDate(props.created)}
{props.created && displayDate(props.created ?? new Date())}
</p>
</div>
@ -111,16 +116,8 @@ const MessageContainer: React.FC<
)}
>
<>
{messageType === ContentType.Image && (
<ImageMessage content={props.content[0]} />
)}
{messageType === ContentType.Pdf && (
<DocMessage
id={props.id}
name={props.content[0]?.text?.name}
size={props.content[0]?.text?.size}
/>
)}
{image && <ImageMessage image={image} />}
{attachedFile && <DocMessage id={props.id} name={props.id} />}
{editMessage === props.id ? (
<div>

View File

@ -22,11 +22,14 @@ import { reloadModelAtom } from '@/hooks/useSendChatMessage'
import ChatBody from '@/screens/Thread/ThreadCenterPanel/ChatBody'
import { uploader } from '@/utils/file'
import ChatInput from './ChatInput'
import RequestDownloadModel from './RequestDownloadModel'
import { showSystemMonitorPanelAtom } from '@/helpers/atoms/App.atom'
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
import {
@ -55,9 +58,9 @@ const ThreadCenterPanel = () => {
const setFileUpload = useSetAtom(fileUploadAtom)
const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom)
const activeThread = useAtomValue(activeThreadAtom)
const acceptedFormat: Accept = activeThread?.assistants[0].model.settings
?.vision_model
const activeAssistant = useAtomValue(activeAssistantAtom)
const upload = uploader()
const acceptedFormat: Accept = activeAssistant?.model.settings?.vision_model
? {
'application/pdf': ['.pdf'],
'image/jpeg': ['.jpeg'],
@ -78,14 +81,13 @@ const ThreadCenterPanel = () => {
if (!experimentalFeature) return
if (
e.dataTransfer.items.length === 1 &&
((activeThread?.assistants[0].tools &&
activeThread?.assistants[0].tools[0]?.enabled) ||
activeThread?.assistants[0].model.settings?.vision_model)
((activeAssistant?.tools && activeAssistant?.tools[0]?.enabled) ||
activeAssistant?.model.settings?.vision_model)
) {
setDragOver(true)
} else if (
activeThread?.assistants[0].tools &&
!activeThread?.assistants[0].tools[0]?.enabled
activeAssistant?.tools &&
!activeAssistant?.tools[0]?.enabled
) {
setDragRejected({ code: 'retrieval-off' })
} else {
@ -93,27 +95,36 @@ const ThreadCenterPanel = () => {
}
},
onDragLeave: () => setDragOver(false),
onDrop: (files, rejectFiles) => {
onDrop: async (files, rejectFiles) => {
// Retrieval file drag and drop is experimental feature
if (!experimentalFeature) return
if (
!files ||
files.length !== 1 ||
rejectFiles.length !== 0 ||
(activeThread?.assistants[0].tools &&
!activeThread?.assistants[0].tools[0]?.enabled &&
!activeThread?.assistants[0].model.settings?.vision_model)
(activeAssistant?.tools &&
!activeAssistant?.tools[0]?.enabled &&
!activeAssistant?.model.settings?.vision_model)
)
return
const imageType = files[0]?.type.includes('image')
setFileUpload([{ file: files[0], type: imageType ? 'image' : 'pdf' }])
if (imageType) {
setFileUpload({ file: files[0], type: 'image' })
} else {
upload.addFile(files[0])
upload.upload().then((data) => {
setFileUpload({
file: files[0],
type: imageType ? 'image' : 'pdf',
id: data?.successful?.[0]?.response?.body?.id,
name: data?.successful?.[0]?.response?.body?.filename,
})
})
}
setDragOver(false)
},
onDropRejected: (e) => {
if (
activeThread?.assistants[0].tools &&
!activeThread?.assistants[0].tools[0]?.enabled
) {
if (activeAssistant?.tools && !activeAssistant?.tools[0]?.enabled) {
setDragRejected({ code: 'retrieval-off' })
} else {
setDragRejected({ code: e[0].errors[0].code })
@ -186,8 +197,7 @@ const ThreadCenterPanel = () => {
<h6 className="font-bold">
{isDragReject
? `Currently, we only support 1 attachment at the same time with ${
activeThread?.assistants[0].model.settings
?.vision_model
activeAssistant?.model.settings?.vision_model
? 'PDF, JPEG, JPG, PNG'
: 'PDF'
} format`
@ -195,7 +205,7 @@ const ThreadCenterPanel = () => {
</h6>
{!isDragReject && (
<p className="mt-2">
{activeThread?.assistants[0].model.settings?.vision_model
{activeAssistant?.model.settings?.vision_model
? 'PDF, JPEG, JPG, PNG'
: 'PDF'}
</p>

View File

@ -15,13 +15,15 @@ const ModalEditTitleThread = () => {
const [modalActionThread, setModalActionThread] = useAtom(
modalActionThreadAtom
)
const [title, setTitle] = useState(modalActionThread.thread?.title as string)
const [title, setTitle] = useState(
modalActionThread.thread?.metadata?.title as string
)
useLayoutEffect(() => {
if (modalActionThread.thread?.title) {
setTitle(modalActionThread.thread?.title)
if (modalActionThread.thread?.metadata?.title) {
setTitle(modalActionThread.thread?.metadata?.title as string)
}
}, [modalActionThread.thread?.title])
}, [modalActionThread.thread?.metadata])
const onUpdateTitle = useCallback(
(e: React.MouseEvent<HTMLButtonElement, MouseEvent>) => {
@ -30,6 +32,10 @@ const ModalEditTitleThread = () => {
updateThreadMetadata({
...modalActionThread?.thread,
title: title || 'New Thread',
metadata: {
...modalActionThread?.thread.metadata,
title: title || 'New Thread',
},
})
},
[modalActionThread?.thread, title, updateThreadMetadata]

View File

@ -20,7 +20,10 @@ import { useCreateNewThread } from '@/hooks/useCreateNewThread'
import useRecommendedModel from '@/hooks/useRecommendedModel'
import useSetActiveThread from '@/hooks/useSetActiveThread'
import { assistantsAtom } from '@/helpers/atoms/Assistant.atom'
import {
activeAssistantAtom,
assistantsAtom,
} from '@/helpers/atoms/Assistant.atom'
import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom'
import {
@ -34,6 +37,7 @@ import {
const ThreadLeftPanel = () => {
const threads = useAtomValue(threadsAtom)
const activeThreadId = useAtomValue(getActiveThreadIdAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const { setActiveThread } = useSetActiveThread()
const assistants = useAtomValue(assistantsAtom)
const threadDataReady = useAtomValue(threadDataReadyAtom)
@ -67,6 +71,7 @@ const ThreadLeftPanel = () => {
useEffect(() => {
if (
threadDataReady &&
activeAssistant &&
assistants.length > 0 &&
threads.length === 0 &&
downloadedModels.length > 0
@ -75,7 +80,10 @@ const ThreadLeftPanel = () => {
(model) => model.engine === InferenceEngine.cortex_llamacpp
)
const selectedModel = model[0] || recommendedModel
requestCreateNewThread(assistants[0], selectedModel)
requestCreateNewThread(
{ ...assistants[0], ...activeAssistant },
selectedModel
)
} else if (threadDataReady && !activeThreadId) {
setActiveThread(threads[0])
}
@ -88,6 +96,7 @@ const ThreadLeftPanel = () => {
setActiveThread,
recommendedModel,
downloadedModels,
activeAssistant,
])
const onContextMenu = (event: React.MouseEvent, thread: Thread) => {
@ -138,7 +147,7 @@ const ThreadLeftPanel = () => {
activeThreadId && 'font-medium'
)}
>
{thread.title}
{thread.title ?? thread.metadata?.title}
</h1>
</div>
<div

View File

@ -14,48 +14,54 @@ import AssistantSetting from '@/screens/Thread/ThreadCenterPanel/AssistantSettin
import { getConfigurationsData } from '@/utils/componentSettings'
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
const Tools = () => {
const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom)
const activeThread = useAtomValue(activeThreadAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
const { updateThreadMetadata } = useCreateNewThread()
const { recommendedModel, downloadedModels } = useRecommendedModel()
const componentDataAssistantSetting = getConfigurationsData(
(activeThread?.assistants[0]?.tools &&
activeThread?.assistants[0]?.tools[0]?.settings) ??
{}
(activeAssistant?.tools && activeAssistant?.tools[0]?.settings) ?? {}
)
useEffect(() => {
if (!activeThread) return
let model = downloadedModels.find(
(model) => model.id === activeThread.assistants[0].model.id
(model) => model.id === activeAssistant?.model.id
)
if (!model) {
model = recommendedModel
}
setSelectedModel(model)
}, [recommendedModel, activeThread, downloadedModels, setSelectedModel])
}, [
recommendedModel,
activeThread,
downloadedModels,
setSelectedModel,
activeAssistant?.model.id,
])
const onRetrievalSwitchUpdate = useCallback(
(enabled: boolean) => {
if (!activeThread) return
if (!activeThread || !activeAssistant) return
updateThreadMetadata({
...activeThread,
assistants: [
{
...activeThread.assistants[0],
...activeAssistant,
tools: [
{
type: 'retrieval',
enabled: enabled,
settings:
(activeThread.assistants[0].tools &&
activeThread.assistants[0].tools[0]?.settings) ??
(activeAssistant.tools &&
activeAssistant.tools[0]?.settings) ??
{},
},
],
@ -63,25 +69,25 @@ const Tools = () => {
],
})
},
[activeThread, updateThreadMetadata]
[activeAssistant, activeThread, updateThreadMetadata]
)
const onTimeWeightedRetrieverSwitchUpdate = useCallback(
(enabled: boolean) => {
if (!activeThread) return
if (!activeThread || !activeAssistant) return
updateThreadMetadata({
...activeThread,
assistants: [
{
...activeThread.assistants[0],
...activeAssistant,
tools: [
{
type: 'retrieval',
enabled: true,
useTimeWeightedRetriever: enabled,
settings:
(activeThread.assistants[0].tools &&
activeThread.assistants[0].tools[0]?.settings) ??
(activeAssistant.tools &&
activeAssistant.tools[0]?.settings) ??
{},
},
],
@ -89,15 +95,14 @@ const Tools = () => {
],
})
},
[activeThread, updateThreadMetadata]
[activeAssistant, activeThread, updateThreadMetadata]
)
if (!experimentalFeature) return null
return (
<Fragment>
{activeThread?.assistants[0]?.tools &&
componentDataAssistantSetting.length > 0 && (
{activeAssistant?.tools && componentDataAssistantSetting.length > 0 && (
<div className="p-4">
<div className="mb-2">
<div className="flex items-center justify-between">
@ -122,13 +127,13 @@ const Tools = () => {
<div className="flex items-center justify-between">
<Switch
name="retrieval"
checked={activeThread?.assistants[0].tools[0].enabled}
checked={activeAssistant?.tools[0].enabled}
onChange={(e) => onRetrievalSwitchUpdate(e.target.checked)}
/>
</div>
</div>
</div>
{activeThread?.assistants[0]?.tools[0].enabled && (
{activeAssistant?.tools[0].enabled && (
<div className="pb-4 pt-2">
<div className="mb-4">
<div className="item-center mb-2 flex">
@ -155,11 +160,7 @@ const Tools = () => {
/>
</div>
<div className="w-full">
<Input
value={selectedModel?.name || ''}
disabled
readOnly
/>
<Input value={selectedModel?.name || ''} disabled readOnly />
</div>
</div>
<div className="mb-4">
@ -214,8 +215,8 @@ const Tools = () => {
<Switch
name="use-time-weighted-retriever"
checked={
activeThread?.assistants[0].tools[0]
.useTimeWeightedRetriever || false
activeAssistant?.tools[0].useTimeWeightedRetriever ||
false
}
onChange={(e) =>
onTimeWeightedRetrieverSwitchUpdate(e.target.checked)
@ -224,9 +225,7 @@ const Tools = () => {
</div>
</div>
</div>
<AssistantSetting
componentData={componentDataAssistantSetting}
/>
<AssistantSetting componentData={componentDataAssistantSetting} />
</div>
)}
</div>

View File

@ -38,6 +38,7 @@ import PromptTemplateSetting from './PromptTemplateSetting'
import Tools from './Tools'
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import {
activeThreadAtom,
@ -53,6 +54,7 @@ const ENGINE_SETTINGS = 'Engine Settings'
const ThreadRightPanel = () => {
const activeThread = useAtomValue(activeThreadAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
const selectedModel = useAtomValue(selectedModelAtom)
const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom(
@ -154,18 +156,18 @@ const ThreadRightPanel = () => {
const onAssistantInstructionChanged = useCallback(
(e: React.ChangeEvent<HTMLTextAreaElement>) => {
if (activeThread)
if (activeThread && activeAssistant)
updateThreadMetadata({
...activeThread,
assistants: [
{
...activeThread.assistants[0],
...activeAssistant,
instructions: e.target.value || '',
},
],
})
},
[activeThread, updateThreadMetadata]
[activeAssistant, activeThread, updateThreadMetadata]
)
const resetModel = useDebouncedCallback(() => {
@ -174,9 +176,7 @@ const ThreadRightPanel = () => {
const onValueChanged = useCallback(
(key: string, value: string | number | boolean | string[]) => {
if (!activeThread) {
return
}
if (!activeThread || !activeAssistant) return
setEngineParamsUpdate(true)
resetModel()
@ -186,32 +186,38 @@ const ThreadRightPanel = () => {
})
if (
activeThread.assistants[0].model.parameters?.max_tokens &&
activeThread.assistants[0].model.settings?.ctx_len
activeAssistant.model.parameters?.max_tokens &&
activeAssistant.model.settings?.ctx_len
) {
if (
key === 'max_tokens' &&
Number(value) > activeThread.assistants[0].model.settings.ctx_len
Number(value) > activeAssistant.model.settings.ctx_len
) {
updateModelParameter(activeThread, {
params: {
max_tokens: activeThread.assistants[0].model.settings.ctx_len,
max_tokens: activeAssistant.model.settings.ctx_len,
},
})
}
if (
key === 'ctx_len' &&
Number(value) < activeThread.assistants[0].model.parameters.max_tokens
Number(value) < activeAssistant.model.parameters.max_tokens
) {
updateModelParameter(activeThread, {
params: {
max_tokens: activeThread.assistants[0].model.settings.ctx_len,
max_tokens: activeAssistant.model.settings.ctx_len,
},
})
}
}
},
[activeThread, resetModel, setEngineParamsUpdate, updateModelParameter]
[
activeAssistant,
activeThread,
resetModel,
setEngineParamsUpdate,
updateModelParameter,
]
)
if (!activeThread) {
@ -250,7 +256,7 @@ const ThreadRightPanel = () => {
<TextArea
id="assistant-instructions"
placeholder="Eg. You are a helpful assistant."
value={activeThread?.assistants[0].instructions ?? ''}
value={activeAssistant?.instructions ?? ''}
autoResize
onChange={onAssistantInstructionChanged}
/>

View File

@ -5,16 +5,19 @@ import { useStarterScreen } from '../../hooks/useStarterScreen'
import '@testing-library/jest-dom'
global.ResizeObserver = class {
observe() {}
unobserve() {}
disconnect() {}
observe() { }
unobserve() { }
disconnect() { }
}
// Mock the useStarterScreen hook
jest.mock('@/hooks/useStarterScreen')
// @ts-ignore
global.API_BASE_URL = 'http://localhost:3000'
describe('ThreadScreen', () => {
it('renders OnDeviceStarterScreen when isShowStarterScreen is true', () => {
;(useStarterScreen as jest.Mock).mockReturnValue({
; (useStarterScreen as jest.Mock).mockReturnValue({
isShowStarterScreen: true,
extensionHasSettings: false,
})
@ -24,7 +27,7 @@ describe('ThreadScreen', () => {
})
it('renders Thread panels when isShowStarterScreen is false', () => {
;(useStarterScreen as jest.Mock).mockReturnValue({
; (useStarterScreen as jest.Mock).mockReturnValue({
isShowStarterScreen: false,
extensionHasSettings: false,
})

2
web/types/file.d.ts vendored
View File

@ -3,4 +3,6 @@ export type FileType = 'image' | 'pdf'
export type FileInfo = {
file: File
type: FileType
id?: string
name?: string
}

View File

@ -1,4 +1,6 @@
import { baseName } from '@janhq/core'
import Uppy from '@uppy/core'
import XHR from '@uppy/xhr-upload'
export type FilePathWithSize = {
path: string
@ -27,3 +29,21 @@ export const getFileInfoFromFile = async (
}
return result
}
/**
* This function creates an Uppy instance with XHR plugin for file upload to the server.
* @returns Uppy instance
*/
export const uploader = () => {
const uppy = new Uppy().use(XHR, {
endpoint: `${API_BASE_URL}/v1/files`,
method: 'POST',
fieldName: 'file',
formData: true,
limit: 1,
})
uppy.setMeta({
purpose: 'assistants',
})
return uppy
}

View File

@ -15,7 +15,7 @@ import { ulid } from 'ulidx'
import { Stack } from '@/utils/Stack'
import { FileType } from '@/types/file'
import { FileInfo, FileType } from '@/types/file'
export class MessageRequestBuilder {
msgId: string
@ -38,7 +38,7 @@ export class MessageRequestBuilder {
.filter((e) => e.status !== MessageStatus.Error)
.map<ChatCompletionMessage>((msg) => ({
role: msg.role,
content: msg.content[0]?.text.value ?? '.',
content: msg.content[0]?.text?.value ?? '.',
}))
}
@ -46,11 +46,11 @@ export class MessageRequestBuilder {
pushMessage(
message: string,
base64Blob: string | undefined,
fileContentType: FileType
fileInfo?: FileInfo
) {
if (base64Blob && fileContentType === 'pdf')
return this.addDocMessage(message)
else if (base64Blob && fileContentType === 'image') {
if (base64Blob && fileInfo?.type === 'pdf')
return this.addDocMessage(message, fileInfo?.name)
else if (base64Blob && fileInfo?.type === 'image') {
return this.addImageMessage(message, base64Blob)
}
this.messages = [
@ -77,7 +77,7 @@ export class MessageRequestBuilder {
}
// Chainable
addDocMessage(prompt: string) {
addDocMessage(prompt: string, name?: string) {
const message: ChatCompletionMessage = {
role: ChatCompletionRole.User,
content: [
@ -88,7 +88,7 @@ export class MessageRequestBuilder {
{
type: ChatCompletionMessageContentType.Doc,
doc_url: {
url: `threads/${this.thread.id}/files/${this.msgId}.pdf`,
url: name ?? `${this.msgId}.pdf`,
},
},
] as ChatCompletionMessageContent,
@ -163,6 +163,7 @@ export class MessageRequestBuilder {
return {
id: this.msgId,
type: this.type,
attachments: [],
threadId: this.thread.id,
messages: this.normalizeMessages(this.messages),
model: this.model,

View File

@ -1,16 +1,19 @@
import {
ChatCompletionRole,
MessageRequestType,
MessageStatus,
} from '@janhq/core'
import { ChatCompletionRole, MessageStatus } from '@janhq/core'
import { ThreadMessageBuilder } from './threadMessageBuilder'
import { MessageRequestBuilder } from './messageRequestBuilder'
import { ThreadMessageBuilder } from './threadMessageBuilder'
import { MessageRequestBuilder } from './messageRequestBuilder'
import { ContentType } from '@janhq/core';
describe('ThreadMessageBuilder', () => {
import { ContentType } from '@janhq/core'
describe('ThreadMessageBuilder', () => {
it('testBuildMethod', () => {
const msgRequest = new MessageRequestBuilder(
'type',
{ model: 'model' },
{ id: 'thread-id' },
MessageRequestType.Thread,
{ model: 'model' } as any,
{ id: 'thread-id' } as any,
[]
)
const builder = new ThreadMessageBuilder(msgRequest)
@ -25,18 +28,18 @@ import { ContentType } from '@janhq/core';
expect(result.object).toBe('thread.message')
expect(result.content).toEqual([])
})
})
})
it('testPushMessageWithPromptOnly', () => {
it('testPushMessageWithPromptOnly', () => {
const msgRequest = new MessageRequestBuilder(
'type',
{ model: 'model' },
{ id: 'thread-id' },
MessageRequestType.Thread,
{ model: 'model' } as any,
{ id: 'thread-id' } as any,
[]
);
const builder = new ThreadMessageBuilder(msgRequest);
const prompt = 'test prompt';
builder.pushMessage(prompt, undefined, []);
)
const builder = new ThreadMessageBuilder(msgRequest)
const prompt = 'test prompt'
builder.pushMessage(prompt, undefined, undefined)
expect(builder.content).toEqual([
{
type: ContentType.Text,
@ -45,56 +48,53 @@ import { ContentType } from '@janhq/core';
annotations: [],
},
},
]);
});
])
})
it('testPushMessageWithPdf', () => {
it('testPushMessageWithPdf', () => {
const msgRequest = new MessageRequestBuilder(
'type',
{ model: 'model' },
{ id: 'thread-id' },
MessageRequestType.Thread,
{ model: 'model' } as any,
{ id: 'thread-id' } as any,
[]
);
const builder = new ThreadMessageBuilder(msgRequest);
const prompt = 'test prompt';
const base64 = 'test base64';
const fileUpload = [{ type: 'pdf', file: { name: 'test.pdf', size: 1000 } }];
builder.pushMessage(prompt, base64, fileUpload);
)
const builder = new ThreadMessageBuilder(msgRequest)
const prompt = 'test prompt'
const base64 = 'test base64'
const fileUpload = [
{ type: 'pdf', file: { name: 'test.pdf', size: 1000 } },
] as any
builder.pushMessage(prompt, base64, fileUpload)
expect(builder.content).toEqual([
{
type: ContentType.Pdf,
type: ContentType.Text,
text: {
value: prompt,
annotations: [base64],
name: fileUpload[0].file.name,
size: fileUpload[0].file.size,
annotations: [],
},
},
]);
});
])
})
it('testPushMessageWithImage', () => {
it('testPushMessageWithImage', () => {
const msgRequest = new MessageRequestBuilder(
'type',
{ model: 'model' },
{ id: 'thread-id' },
MessageRequestType.Thread,
{ model: 'model' } as any,
{ id: 'thread-id' } as any,
[]
);
const builder = new ThreadMessageBuilder(msgRequest);
const prompt = 'test prompt';
const base64 = 'test base64';
const fileUpload = [{ type: 'image', file: { name: 'test.jpg', size: 1000 } }];
builder.pushMessage(prompt, base64, fileUpload);
)
const builder = new ThreadMessageBuilder(msgRequest)
const prompt = 'test prompt'
const base64 = 'test base64'
const fileUpload = [{ type: 'image', file: { name: 'test.jpg', size: 1000 } }]
builder.pushMessage(prompt, base64, fileUpload as any)
expect(builder.content).toEqual([
{
type: ContentType.Image,
type: ContentType.Text,
text: {
value: prompt,
annotations: [base64],
annotations: [],
},
},
]);
});
])
})

View File

@ -1,4 +1,5 @@
import {
Attachment,
ChatCompletionRole,
ContentType,
MessageStatus,
@ -14,6 +15,7 @@ export class ThreadMessageBuilder {
messageRequest: MessageRequestBuilder
content: ThreadContent[] = []
attachments: Attachment[] = []
constructor(messageRequest: MessageRequestBuilder) {
this.messageRequest = messageRequest
@ -24,6 +26,7 @@ export class ThreadMessageBuilder {
return {
id: this.messageRequest.msgId,
thread_id: this.messageRequest.thread.id,
attachments: this.attachments,
role: ChatCompletionRole.User,
status: MessageStatus.Ready,
created: timestamp,
@ -36,31 +39,9 @@ export class ThreadMessageBuilder {
pushMessage(
prompt: string,
base64: string | undefined,
fileUpload: FileInfo[]
fileUpload?: FileInfo
) {
if (base64 && fileUpload[0]?.type === 'image') {
this.content.push({
type: ContentType.Image,
text: {
value: prompt,
annotations: [base64],
},
})
}
if (base64 && fileUpload[0]?.type === 'pdf') {
this.content.push({
type: ContentType.Pdf,
text: {
value: prompt,
annotations: [base64],
name: fileUpload[0].file.name,
size: fileUpload[0].file.size,
},
})
}
if (prompt && !base64) {
if (prompt) {
this.content.push({
type: ContentType.Text,
text: {
@ -69,6 +50,26 @@ export class ThreadMessageBuilder {
},
})
}
if (base64 && fileUpload?.type === 'image') {
this.content.push({
type: ContentType.Image,
image_url: {
url: base64,
},
})
}
if (base64 && fileUpload?.type === 'pdf') {
this.attachments.push({
file_id: fileUpload.id,
tools: [
{
type: 'file_search',
},
],
})
}
return this
}
}