From ba282d637e6142a607cc7bcf673280aaca83d363 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 10 Mar 2025 15:23:15 +0700 Subject: [PATCH] fix: anthropic response template correction --- .../extensions/engines/EngineManager.test.ts | 66 +++ .../engine-management-extension/package.json | 2 +- .../resources/anthropic.json | 2 +- .../engine-management-extension/src/index.ts | 7 +- .../src/index.test.ts | 452 ++++++++++++++++++ 5 files changed, 525 insertions(+), 4 deletions(-) create mode 100644 extensions/inference-cortex-extension/src/index.test.ts diff --git a/core/src/browser/extensions/engines/EngineManager.test.ts b/core/src/browser/extensions/engines/EngineManager.test.ts index c1f1fcb71..319dc792a 100644 --- a/core/src/browser/extensions/engines/EngineManager.test.ts +++ b/core/src/browser/extensions/engines/EngineManager.test.ts @@ -3,6 +3,7 @@ */ import { EngineManager } from './EngineManager' import { AIEngine } from './AIEngine' +import { InferenceEngine } from '../../../types' // @ts-ignore class MockAIEngine implements AIEngine { @@ -40,4 +41,69 @@ describe('EngineManager', () => { const retrievedEngine = engineManager.get('nonExistentProvider') expect(retrievedEngine).toBeUndefined() }) + + describe('cortex engine migration', () => { + test('should map nitro to cortex engine', () => { + const cortexEngine = new MockAIEngine(InferenceEngine.cortex) + // @ts-ignore + engineManager.register(cortexEngine) + + // @ts-ignore + const retrievedEngine = engineManager.get(InferenceEngine.nitro) + expect(retrievedEngine).toBe(cortexEngine) + }) + + test('should map cortex_llamacpp to cortex engine', () => { + const cortexEngine = new MockAIEngine(InferenceEngine.cortex) + // @ts-ignore + engineManager.register(cortexEngine) + + // @ts-ignore + const retrievedEngine = engineManager.get(InferenceEngine.cortex_llamacpp) + expect(retrievedEngine).toBe(cortexEngine) + }) + + test('should map cortex_onnx to cortex engine', () => { + const cortexEngine = new MockAIEngine(InferenceEngine.cortex) + // @ts-ignore + engineManager.register(cortexEngine) + + // @ts-ignore + const retrievedEngine = engineManager.get(InferenceEngine.cortex_onnx) + expect(retrievedEngine).toBe(cortexEngine) + }) + + test('should map cortex_tensorrtllm to cortex engine', () => { + const cortexEngine = new MockAIEngine(InferenceEngine.cortex) + // @ts-ignore + engineManager.register(cortexEngine) + + // @ts-ignore + const retrievedEngine = engineManager.get(InferenceEngine.cortex_tensorrtllm) + expect(retrievedEngine).toBe(cortexEngine) + }) + }) + + describe('singleton instance', () => { + test('should return the window.core.engineManager if available', () => { + const mockEngineManager = new EngineManager() + // @ts-ignore + window.core = { engineManager: mockEngineManager } + + const instance = EngineManager.instance() + expect(instance).toBe(mockEngineManager) + + // Clean up + // @ts-ignore + delete window.core + }) + + test('should create a new instance if window.core.engineManager is not available', () => { + // @ts-ignore + delete window.core + + const instance = EngineManager.instance() + expect(instance).toBeInstanceOf(EngineManager) + }) + }) }) diff --git a/extensions/engine-management-extension/package.json b/extensions/engine-management-extension/package.json index 4664a7462..cf774f6a2 100644 --- a/extensions/engine-management-extension/package.json +++ b/extensions/engine-management-extension/package.json @@ -1,7 +1,7 @@ { "name": "@janhq/engine-management-extension", "productName": "Engine Management", - "version": "1.0.1", + "version": "1.0.2", "description": "Manages AI engines and their configurations.", "main": "dist/index.js", "node": "dist/node/index.cjs.js", diff --git a/extensions/engine-management-extension/resources/anthropic.json b/extensions/engine-management-extension/resources/anthropic.json index 771a2c9ff..f8ba74e2b 100644 --- a/extensions/engine-management-extension/resources/anthropic.json +++ b/extensions/engine-management-extension/resources/anthropic.json @@ -15,7 +15,7 @@ }, "transform_resp": { "chat_completions": { - "template": "{% if input_request.stream %} {\"object\": \"chat.completion.chunk\", \"model\": \"{{ input_request.model }}\", \"choices\": [{\"index\": 0, \"delta\": { {% if input_request.type == \"message_start\" %} \"role\": \"assistant\", \"content\": null {% else if input_request.type == \"ping\" %} \"role\": \"assistant\", \"content\": null {% else if input_request.type == \"content_block_delta\" %} \"role\": \"assistant\", \"content\": \"{{ tojson(input_request.delta.text) }}\" {% else if input_request.type == \"content_block_stop\" %} \"role\": \"assistant\", \"content\": null {% else if input_request.type == \"content_block_stop\" %} \"role\": \"assistant\", \"content\": null {% endif %} }, {% if input_request.type == \"content_block_stop\" %} \"finish_reason\": \"stop\" {% else %} \"finish_reason\": null {% endif %} }]} {% else %} {{tojson(input_request)}} {% endif %}" + "template": "{% if input_request.stream %} {\"object\": \"chat.completion.chunk\", \"model\": \"{{ input_request.model }}\", \"choices\": [{\"index\": 0, \"delta\": { {% if input_request.type == \"message_start\" %} \"role\": \"assistant\", \"content\": null {% else if input_request.type == \"ping\" %} \"role\": \"assistant\", \"content\": null {% else if input_request.type == \"content_block_delta\" %} \"role\": \"assistant\", \"content\": {{ tojson(input_request.delta.text) }} {% else if input_request.type == \"content_block_stop\" %} \"role\": \"assistant\", \"content\": null {% else if input_request.type == \"content_block_stop\" %} \"role\": \"assistant\", \"content\": null {% endif %} }, {% if input_request.type == \"content_block_stop\" %} \"finish_reason\": \"stop\" {% else %} \"finish_reason\": null {% endif %} }]} {% else %} {{tojson(input_request)}} {% endif %}" } }, "explore_models_url": "https://docs.anthropic.com/en/docs/about-claude/models" diff --git a/extensions/engine-management-extension/src/index.ts b/extensions/engine-management-extension/src/index.ts index 3b11db020..1a5b004f7 100644 --- a/extensions/engine-management-extension/src/index.ts +++ b/extensions/engine-management-extension/src/index.ts @@ -286,7 +286,8 @@ export default class JanEngineManagementExtension extends EngineManagementExtens if ( !installedEngines.some( (e) => e.name === variant.variant && e.version === variant.version - ) || variant.version < CORTEX_ENGINE_VERSION + ) || + variant.version < CORTEX_ENGINE_VERSION ) { throw new EngineError( 'Default engine is not available, use bundled version.' @@ -397,7 +398,9 @@ export default class JanEngineManagementExtension extends EngineManagementExtens const { id, ...data } = engine data.api_key = engines[id]?.api_key - return this.updateEngine(data).catch(console.error) + return this.updateEngine(id,{ + ...data, + }).catch(console.error) }) ) await this.updateSettings([ diff --git a/extensions/inference-cortex-extension/src/index.test.ts b/extensions/inference-cortex-extension/src/index.test.ts new file mode 100644 index 000000000..9726400e7 --- /dev/null +++ b/extensions/inference-cortex-extension/src/index.test.ts @@ -0,0 +1,452 @@ +import { describe, beforeEach, it, expect, vi, afterEach } from 'vitest' + +// Must mock before imports +vi.mock('@janhq/core', () => { + return { + executeOnMain: vi.fn().mockResolvedValue({}), + events: { + emit: vi.fn() + }, + extractModelLoadParams: vi.fn().mockReturnValue({}), + ModelEvent: { + OnModelsUpdate: 'OnModelsUpdate', + OnModelStopped: 'OnModelStopped' + }, + EngineEvent: { + OnEngineUpdate: 'OnEngineUpdate' + }, + InferenceEngine: { + cortex: 'cortex', + nitro: 'nitro', + cortex_llamacpp: 'cortex_llamacpp' + }, + LocalOAIEngine: class LocalOAIEngine { + onLoad() {} + onUnload() {} + } + } +}) + +import JanInferenceCortexExtension, { Settings } from './index' +import { InferenceEngine, ModelEvent, EngineEvent, executeOnMain, events } from '@janhq/core' +import ky from 'ky' + +// Mock global variables +const CORTEX_API_URL = 'http://localhost:3000' +const CORTEX_SOCKET_URL = 'ws://localhost:3000' +const SETTINGS = [ + { id: 'n_parallel', name: 'Parallel Execution', description: 'Number of parallel executions', type: 'number', value: '4' }, + { id: 'cont_batching', name: 'Continuous Batching', description: 'Enable continuous batching', type: 'boolean', value: true }, + { id: 'caching_enabled', name: 'Caching', description: 'Enable caching', type: 'boolean', value: true }, + { id: 'flash_attn', name: 'Flash Attention', description: 'Enable flash attention', type: 'boolean', value: true }, + { id: 'cache_type', name: 'Cache Type', description: 'Type of cache to use', type: 'string', value: 'f16' }, + { id: 'use_mmap', name: 'Use Memory Map', description: 'Use memory mapping', type: 'boolean', value: true }, + { id: 'cpu_threads', name: 'CPU Threads', description: 'Number of CPU threads', type: 'number', value: '' } +] +const NODE = 'node' + +// Mock globals +vi.stubGlobal('CORTEX_API_URL', CORTEX_API_URL) +vi.stubGlobal('CORTEX_SOCKET_URL', CORTEX_SOCKET_URL) +vi.stubGlobal('SETTINGS', SETTINGS) +vi.stubGlobal('NODE', NODE) +vi.stubGlobal('window', { + addEventListener: vi.fn() +}) + +// Mock WebSocket +class MockWebSocket { + url :string + listeners: {} + onclose: Function + + constructor(url) { + this.url = url + this.listeners = {} + } + + addEventListener(event, listener) { + this.listeners[event] = listener + } + + emit(event, data) { + if (this.listeners[event]) { + this.listeners[event](data) + } + } + + close() { + if (this.onclose) { + this.onclose({ code: 1000 }) + } + } +} + +// Mock global WebSocket +// @ts-ignore +global.WebSocket = vi.fn().mockImplementation((url) => new MockWebSocket(url)) + +describe('JanInferenceCortexExtension', () => { + let extension + + beforeEach(() => { + // Reset mocks + vi.clearAllMocks() + + // Create a new instance for each test + extension = new JanInferenceCortexExtension() + + // Mock the getSetting method + extension.getSetting = vi.fn().mockImplementation((key, defaultValue) => { + switch(key) { + case Settings.n_parallel: + return '4' + case Settings.cont_batching: + return true + case Settings.caching_enabled: + return true + case Settings.flash_attn: + return true + case Settings.cache_type: + return 'f16' + case Settings.use_mmap: + return true + case Settings.cpu_threads: + return '' + default: + return defaultValue + } + }) + + // Mock methods + extension.registerSettings = vi.fn() + extension.onLoad = vi.fn() + extension.clean = vi.fn().mockResolvedValue({}) + extension.healthz = vi.fn().mockResolvedValue({}) + extension.subscribeToEvents = vi.fn() + }) + + describe('onSettingUpdate', () => { + it('should update n_parallel setting correctly', () => { + extension.onSettingUpdate(Settings.n_parallel, '8') + expect(extension.n_parallel).toBe(8) + }) + + it('should update cont_batching setting correctly', () => { + extension.onSettingUpdate(Settings.cont_batching, false) + expect(extension.cont_batching).toBe(false) + }) + + it('should update caching_enabled setting correctly', () => { + extension.onSettingUpdate(Settings.caching_enabled, false) + expect(extension.caching_enabled).toBe(false) + }) + + it('should update flash_attn setting correctly', () => { + extension.onSettingUpdate(Settings.flash_attn, false) + expect(extension.flash_attn).toBe(false) + }) + + it('should update cache_type setting correctly', () => { + extension.onSettingUpdate(Settings.cache_type, 'f32') + expect(extension.cache_type).toBe('f32') + }) + + it('should update use_mmap setting correctly', () => { + extension.onSettingUpdate(Settings.use_mmap, false) + expect(extension.use_mmap).toBe(false) + }) + + it('should update cpu_threads setting correctly', () => { + extension.onSettingUpdate(Settings.cpu_threads, '4') + expect(extension.cpu_threads).toBe(4) + }) + + it('should not update cpu_threads when value is not a number', () => { + extension.cpu_threads = undefined + extension.onSettingUpdate(Settings.cpu_threads, 'not-a-number') + expect(extension.cpu_threads).toBeUndefined() + }) + }) + + describe('onUnload', () => { + it('should clean up resources correctly', async () => { + extension.shouldReconnect = true + + await extension.onUnload() + + expect(extension.shouldReconnect).toBe(false) + expect(extension.clean).toHaveBeenCalled() + expect(executeOnMain).toHaveBeenCalledWith(NODE, 'dispose') + }) + }) + + describe('loadModel', () => { + it('should remove llama_model_path and mmproj from settings', async () => { + // Setup + const model = { + id: 'test-model', + settings: { + llama_model_path: '/path/to/model', + mmproj: '/path/to/mmproj', + some_setting: 'value' + }, + engine: InferenceEngine.cortex_llamacpp + } + + // Mock ky.post + vi.spyOn(ky, 'post').mockImplementation(() => ({ + // @ts-ignore + json: () => Promise.resolve({}), + catch: () => ({ + finally: () => ({ + // @ts-ignore + then: () => Promise.resolve({}) + }) + }) + })) + + // Setup queue for testing + extension.queue = { add: vi.fn(fn => fn()) } + + // Execute + await extension.loadModel(model) + + // Verify settings were filtered + expect(model.settings).not.toHaveProperty('llama_model_path') + expect(model.settings).not.toHaveProperty('mmproj') + expect(model.settings).toHaveProperty('some_setting') + }) + + it('should convert nitro to cortex_llamacpp engine', async () => { + // Setup + const model = { + id: 'test-model', + settings: {}, + engine: InferenceEngine.nitro + } + + // Mock ky.post + const mockKyPost = vi.spyOn(ky, 'post').mockImplementation(() => ({ + // @ts-ignore + json: () => Promise.resolve({}), + catch: () => ({ + finally: () => ({ + // @ts-ignore + then: () => Promise.resolve({}) + }) + }) + })) + + // Setup queue for testing + extension.queue = { add: vi.fn(fn => fn()) } + + // Execute + await extension.loadModel(model) + + // Verify API call + expect(mockKyPost).toHaveBeenCalledWith( + `${CORTEX_API_URL}/v1/models/start`, + expect.objectContaining({ + json: expect.objectContaining({ + engine: InferenceEngine.cortex_llamacpp + }) + }) + ) + }) + }) + + describe('unloadModel', () => { + it('should call the correct API endpoint and abort loading if in progress', async () => { + // Setup + const model = { id: 'test-model' } + const mockAbort = vi.fn() + extension.abortControllers.set(model.id, { abort: mockAbort }) + + // Mock ky.post + const mockKyPost = vi.spyOn(ky, 'post').mockImplementation(() => ({ + // @ts-ignore + json: () => Promise.resolve({}), + finally: () => ({ + // @ts-ignore + then: () => Promise.resolve({}) + }) + })) + + // Execute + await extension.unloadModel(model) + + // Verify API call + expect(mockKyPost).toHaveBeenCalledWith( + `${CORTEX_API_URL}/v1/models/stop`, + expect.objectContaining({ + json: { model: model.id } + }) + ) + + // Verify abort controller was called + expect(mockAbort).toHaveBeenCalled() + }) + }) + + describe('clean', () => { + it('should make a DELETE request to destroy process manager', async () => { + // Mock the ky.delete function directly + const mockDelete = vi.fn().mockReturnValue({ + catch: vi.fn().mockReturnValue(Promise.resolve({})) + }) + + // Replace the original implementation + vi.spyOn(ky, 'delete').mockImplementation(mockDelete) + + // Override the clean method to use the real implementation + // @ts-ignore + extension.clean = JanInferenceCortexExtension.prototype.clean + + // Call the method + await extension.clean() + + // Verify the correct API call was made + expect(mockDelete).toHaveBeenCalledWith( + `${CORTEX_API_URL}/processmanager/destroy`, + expect.objectContaining({ + timeout: 2000, + retry: expect.objectContaining({ + limit: 0 + }) + }) + ) + }) + }) + + describe('WebSocket events', () => { + it('should handle WebSocket events correctly', () => { + // Create a mock implementation for subscribeToEvents that stores the socket + let messageHandler; + let closeHandler; + + // Override the private method + extension.subscribeToEvents = function() { + this.socket = new MockWebSocket('ws://localhost:3000/events'); + this.socket.addEventListener('message', (event) => { + const data = JSON.parse(event.data); + + // Store for testing + messageHandler = data; + + const transferred = data.task.items.reduce( + (acc, cur) => acc + cur.downloadedBytes, + 0 + ); + const total = data.task.items.reduce( + (acc, cur) => acc + cur.bytes, + 0 + ); + const percent = total > 0 ? transferred / total : 0; + + events.emit( + data.type === 'DownloadUpdated' ? 'onFileDownloadUpdate' : + data.type === 'DownloadSuccess' ? 'onFileDownloadSuccess' : + data.type, + { + modelId: data.task.id, + percent: percent, + size: { + transferred: transferred, + total: total, + }, + downloadType: data.task.type, + } + ); + + if (data.task.type === 'Engine') { + events.emit(EngineEvent.OnEngineUpdate, { + type: data.type, + percent: percent, + id: data.task.id, + }); + } + else if (data.type === 'DownloadSuccess') { + setTimeout(() => { + events.emit(ModelEvent.OnModelsUpdate, { + fetch: true, + }); + }, 500); + } + }); + + this.socket.onclose = (event) => { + closeHandler = event; + // Notify app to update model running state + events.emit(ModelEvent.OnModelStopped, {}); + }; + }; + + // Setup queue + extension.queue = { + add: vi.fn(fn => fn()) + }; + + // Execute the method + extension.subscribeToEvents(); + + // Simulate a message event + extension.socket.listeners.message({ + data: JSON.stringify({ + type: 'DownloadUpdated', + task: { + id: 'test-model', + type: 'Model', + items: [ + { downloadedBytes: 50, bytes: 100 } + ] + } + }) + }); + + // Verify event emission + expect(events.emit).toHaveBeenCalledWith( + 'onFileDownloadUpdate', + expect.objectContaining({ + modelId: 'test-model', + percent: 0.5 + }) + ); + + // Simulate a download success event + vi.useFakeTimers(); + extension.socket.listeners.message({ + data: JSON.stringify({ + type: 'DownloadSuccess', + task: { + id: 'test-model', + type: 'Model', + items: [ + { downloadedBytes: 100, bytes: 100 } + ] + } + }) + }); + + // Fast-forward time to trigger the timeout + vi.advanceTimersByTime(500); + + // Verify the ModelEvent.OnModelsUpdate event was emitted + expect(events.emit).toHaveBeenCalledWith( + ModelEvent.OnModelsUpdate, + { fetch: true } + ); + + vi.useRealTimers(); + + // Trigger websocket close + extension.socket.onclose({ code: 1000 }); + + // Verify OnModelStopped event was emitted + expect(events.emit).toHaveBeenCalledWith( + ModelEvent.OnModelStopped, + {} + ); + }); + }) +}) \ No newline at end of file