refactor load/unload again; move types to core and refactor AIEngine abstract class
This commit is contained in:
parent ee2cb9e625
commit a7a2dcc8d8
@@ -1,15 +1,208 @@
import { events } from '../../events'
import { BaseExtension } from '../../extension'
import { MessageRequest, Model, ModelEvent } from '../../../types'
import { EngineManager } from './EngineManager'

/* AIEngine class types */

export interface chatCompletionRequestMessage {
  role: 'system' | 'user' | 'assistant' | 'tool'
  content: string | null
  name?: string
  tool_calls?: any[] // Simplified
  tool_call_id?: string
}

export interface chatCompletionRequest {
  provider: string,
  model: string // Model ID, though for local it might be implicit via sessionId
  messages: chatCompletionRequestMessage[]
  temperature?: number | null
  top_p?: number | null
  n?: number | null
  stream?: boolean | null
  stop?: string | string[] | null
  max_tokens?: number
  presence_penalty?: number | null
  frequency_penalty?: number | null
  logit_bias?: { [key: string]: number } | null
  user?: string
  // ... TODO: other OpenAI params
}

export interface chatCompletionChunkChoiceDelta {
  content?: string | null
  role?: 'system' | 'user' | 'assistant' | 'tool'
  tool_calls?: any[] // Simplified
}

export interface chatCompletionChunkChoice {
  index: number
  delta: chatCompletionChunkChoiceDelta
  finish_reason?: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call' | null
}

export interface chatCompletionChunk {
  id: string
  object: 'chat.completion.chunk'
  created: number
  model: string
  choices: chatCompletionChunkChoice[]
  system_fingerprint?: string
}

export interface chatCompletionChoice {
  index: number
  message: chatCompletionRequestMessage // Response message
  finish_reason: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call'
  logprobs?: any // Simplified
}

export interface chatCompletion {
  id: string
  object: 'chat.completion'
  created: number
  model: string // Model ID used
  choices: chatCompletionChoice[]
  usage?: {
    prompt_tokens: number
    completion_tokens: number
    total_tokens: number
  }
  system_fingerprint?: string
}
// --- End OpenAI types ---

// Shared model metadata
export interface modelInfo {
  id: string // e.g. "qwen3-4B" or "org/model/quant"
  name: string // human-readable, e.g., "Qwen3 4B Q4_0"
  quant_type?: string // q4_0 (optional as it might be part of ID or name)
  providerId: string // e.g. "llama.cpp"
  port: number
  sizeBytes: number
  tags?: string[]
  path?: string // Absolute path to the model file, if applicable
  // Additional provider-specific metadata can be added here
  [key: string]: any
}

// 1. /list
export interface listOptions {
  providerId: string // To specify which provider if a central manager calls this
}
export type listResult = modelInfo[]

// 2. /pull
export interface pullOptions {
  providerId: string
  modelId: string // Identifier for the model to pull (e.g., from a known registry)
  downloadUrl: string // URL to download the model from
  /** optional callback to receive download progress */
  onProgress?: (progress: { percent: number; downloadedBytes: number; totalBytes?: number }) => void
}
export interface pullResult {
  success: boolean
  path?: string // local file path to the pulled model
  error?: string
  modelInfo?: modelInfo // Info of the pulled model
}
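
// --- Usage sketch (illustration only, not part of this commit) ---------------
// Roughly how pull() could be driven through the options above. The engine
// instance, model id, and download URL are hypothetical placeholders; the type
// exports from '@janhq/core' follow the import hunk further down in this diff.
import type { AIEngine, pullOptions, pullResult } from '@janhq/core'

async function downloadModel(engine: AIEngine): Promise<pullResult> {
  const opts: pullOptions = {
    providerId: 'llamacpp',
    modelId: 'qwen3-4B', // example id, mirrors the modelInfo comment above
    downloadUrl: 'https://example.com/qwen3-4b-q4_0.gguf', // placeholder URL
    onProgress: ({ percent, downloadedBytes, totalBytes }) =>
      console.log(`pull ${percent}% (${downloadedBytes}/${totalBytes ?? '?'} bytes)`),
  }
  const result = await engine.pull(opts)
  if (!result.success) throw new Error(result.error ?? 'pull failed')
  return result
}
// ------------------------------------------------------------------------------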

// 3. /load
export interface loadOptions {
  modelPath: string
  port?: number
  n_gpu_layers?: number
  n_ctx?: number
  threads?: number
  threads_batch?: number
  ctx_size?: number
  n_predict?: number
  batch_size?: number
  ubatch_size?: number
  device?: string
  split_mode?: string
  main_gpu?: number
  flash_attn?: boolean
  cont_batching?: boolean
  no_mmap?: boolean
  mlock?: boolean
  no_kv_offload?: boolean
  cache_type_k?: string
  cache_type_v?: string
  defrag_thold?: number
  rope_scaling?: string
  rope_scale?: number
  rope_freq_base?: number
  rope_freq_scale?: number
}
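
// --- Usage sketch (illustration only, not part of this commit) ---------------
// A plausible loadOptions payload for the interface above; the path and tuning
// values are made up. Most fields map 1:1 onto llama-server CLI flags, as the
// llama.cpp extension's load() further down in this diff shows.
import type { loadOptions } from '@janhq/core'

const exampleLoadOptions: loadOptions = {
  modelPath: '/models/qwen3-4b-q4_0.gguf', // hypothetical absolute path
  port: 8080, // the llama.cpp extension falls back to 8080 when omitted
  n_gpu_layers: 99, // illustrative: offload as many layers as possible
  ctx_size: 4096,
  flash_attn: true,
}
// ------------------------------------------------------------------------------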

export interface sessionInfo {
  sessionId: string // opaque handle for unload/chat
  port: number // llama-server output port (corrected from portid)
  modelName: string, //name of the model
  modelPath: string // path of the loaded model
}

// 4. /unload
export interface unloadOptions {
  providerId: string
  sessionId: string
}
export interface unloadResult {
  success: boolean
  error?: string
}

// 5. /chat
export interface chatOptions {
  providerId: string
  sessionId: string
  /** Full OpenAI ChatCompletionRequest payload */
  payload: chatCompletionRequest
}
// Output for /chat will be Promise<ChatCompletion> for non-streaming
// or Promise<AsyncIterable<ChatCompletionChunk>> for streaming
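
// --- Usage sketch (illustration only, not part of this commit) ---------------
// Callers of AIEngine.chat() have to branch on that union: a chatCompletion when
// stream is falsy, an AsyncIterable of chunks when stream is true. The engine
// instance and model id below are hypothetical.
import type {
  AIEngine,
  chatCompletion,
  chatCompletionChunk,
  chatCompletionRequest,
} from '@janhq/core'

async function ask(engine: AIEngine, stream: boolean): Promise<string> {
  const request: chatCompletionRequest = {
    provider: 'llamacpp',
    model: 'qwen3-4B', // must match a loaded model
    messages: [{ role: 'user', content: 'Hello!' }],
    stream,
  }
  const result = await engine.chat(request)
  if (stream) {
    let text = ''
    for await (const chunk of result as AsyncIterable<chatCompletionChunk>) {
      text += chunk.choices[0]?.delta.content ?? ''
    }
    return text
  }
  return (result as chatCompletion).choices[0]?.message.content ?? ''
}
// ------------------------------------------------------------------------------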

// 6. /delete
export interface deleteOptions {
  providerId: string
  modelId: string // The ID of the model to delete (implies finding its path)
  modelPath?: string // Optionally, direct path can be provided
}
export interface deleteResult {
  success: boolean
  error?: string
}

// 7. /import
export interface importOptions {
  providerId: string
  sourcePath: string // Path to the local model file to import
  desiredModelId?: string // Optional: if user wants to name it specifically
}
export interface importResult {
  success: boolean
  modelInfo?: modelInfo
  error?: string
}

// 8. /abortPull
export interface abortPullOptions {
  providerId: string
  modelId: string // The modelId whose download is to be aborted
}
export interface abortPullResult {
  success: boolean
  error?: string
}

/**
 * Base AIEngine
 * Applicable to all AI Engines
 */
export abstract class AIEngine extends BaseExtension {
  // The inference engine
  abstract provider: string
  // The inference engine ID, implementing the readonly providerId from interface
  abstract readonly provider: string

  /**
   * On extension load, subscribe to events.
@@ -24,4 +217,49 @@ export abstract class AIEngine extends BaseExtension {
  registerEngine() {
    EngineManager.instance().register(this)
  }

  /**
   * Lists available models
   */
  abstract list(opts: listOptions): Promise<listResult>

  /**
   * Pulls/downloads a model
   */
  abstract pull(opts: pullOptions): Promise<pullResult>

  /**
   * Loads a model into memory
   */
  abstract load(opts: loadOptions): Promise<sessionInfo>

  /**
   * Unloads a model from memory
   */
  abstract unload(opts: unloadOptions): Promise<unloadResult>

  /**
   * Sends a chat request to the model
   */
  abstract chat(opts: chatCompletionRequest): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>>

  /**
   * Deletes a model
   */
  abstract delete(opts: deleteOptions): Promise<deleteResult>

  /**
   * Imports a model
   */
  abstract import(opts: importOptions): Promise<importResult>

  /**
   * Aborts an ongoing model pull
   */
  abstract abortPull(opts: abortPullOptions): Promise<abortPullResult>

  /**
   * Optional method to get the underlying chat client
   */
  getChatClient?(sessionId: string): any
}
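
// --- Usage sketch (illustration only, not part of this commit) ---------------
// Once a concrete engine has called registerEngine(), other code can look it up
// by provider id and use the abstract surface above. This assumes EngineManager
// is exported from '@janhq/core' and that get() returns the registered engine
// (or undefined), as the stopModel() hunk at the end of this diff suggests.
import { AIEngine, EngineManager, listResult } from '@janhq/core'

async function listLlamaCppModels(): Promise<listResult> {
  const engine = EngineManager.instance().get('llamacpp') as AIEngine | undefined
  if (!engine) throw new Error('llamacpp engine is not registered')
  return engine.list({ providerId: 'llamacpp' })
}
// ------------------------------------------------------------------------------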

@@ -30,8 +30,7 @@
  },
  "files": [
    "dist/*",
    "package.json",
    "README.md"
    "package.json"
  ],
  "bundleDependencies": [
    "fetch-retry"

@@ -6,11 +6,11 @@
 * @module llamacpp-extension/src/index
 */

import { AIEngine, getJanDataFolderPath, fs, joinPath } from '@janhq/core'

import { invoke } from '@tauri-apps/api/tauri'
import {
  localProvider,
  AIEngine,
  getJanDataFolderPath,
  fs,
  joinPath,
  modelInfo,
  listOptions,
  listResult,
@@ -30,7 +30,9 @@ import {
  abortPullOptions,
  abortPullResult,
  chatCompletionRequest,
} from './types'
} from '@janhq/core'

import { invoke } from '@tauri-apps/api/tauri'

/**
 * Helper to convert GGUF model filename to a more structured ID/name
@@ -56,10 +58,7 @@ function parseGGUFFileName(filename: string): {
 * The class provides methods for initializing and stopping a model, and for making inference requests.
 * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
 */
export default class llamacpp_extension
  extends AIEngine
  implements localProvider
{
export default class llamacpp_extension extends AIEngine {
  provider: string = 'llamacpp'
  readonly providerId: string = 'llamacpp'

@@ -79,17 +78,20 @@ export default class llamacpp_extension
  }

  // Implement the required LocalProvider interface methods
  async list(opts: listOptions): Promise<listResult> {
  override async list(opts: listOptions): Promise<listResult> {
    throw new Error('method not implemented yet')
  }

  async pull(opts: pullOptions): Promise<pullResult> {
  override async pull(opts: pullOptions): Promise<pullResult> {
    throw new Error('method not implemented yet')
  }

  async load(opts: loadOptions): Promise<sessionInfo> {
  override async load(opts: loadOptions): Promise<sessionInfo> {
    const args: string[] = []

    // disable llama-server webui
    args.push('--no-webui')

    // model option is required
    args.push('-m', opts.modelPath)
    args.push('--port', String(opts.port || 8080)) // Default port if not specified
@@ -193,24 +195,24 @@ export default class llamacpp_extension
    console.log('Calling Tauri command load with args:', args)

    try {
      const sessionInfo = await invoke<sessionInfo>('plugin:llamacpp|load', {
      const sInfo = await invoke<sessionInfo>('load_llama_model', {
        args: args,
      })

      // Store the session info for later use
      this.activeSessions.set(sessionInfo.sessionId, sessionInfo)
      this.activeSessions.set(sInfo.sessionId, sInfo)

      return sessionInfo
      return sInfo
    } catch (error) {
      console.error('Error loading llama-server:', error)
      throw new Error(`Failed to load llama-server: ${error}`)
    }
  }

  async unload(opts: unloadOptions): Promise<unloadResult> {
  override async unload(opts: unloadOptions): Promise<unloadResult> {
    try {
      // Pass the PID as the session_id
      const result = await invoke<unloadResult>('plugin:llamacpp|unload', {
      const result = await invoke<unloadResult>('unload_llama_model', {
        session_id: opts.sessionId, // Using PID as session ID
      })

@@ -232,27 +234,125 @@ export default class llamacpp_extension
    }
  }

  async chat(
    opts: chatOptions
  private async *handleStreamingResponse(
    url: string,
    headers: HeadersInit,
    body: string
  ): AsyncIterable<chatCompletionChunk> {
    const response = await fetch(url, {
      method: 'POST',
      headers,
      body,
    })
    if (!response.ok) {
      const errorData = await response.json().catch(() => null)
      throw new Error(
        `API request failed with status ${response.status}: ${JSON.stringify(errorData)}`
      )
    }

    if (!response.body) {
      throw new Error('Response body is null')
    }

    const reader = response.body.getReader()
    const decoder = new TextDecoder('utf-8')
    let buffer = ''
    try {
      while (true) {
        const { done, value } = await reader.read()

        if (done) {
          break
        }

        buffer += decoder.decode(value, { stream: true })

        // Process complete lines in the buffer
        const lines = buffer.split('\n')
        buffer = lines.pop() || '' // Keep the last incomplete line in the buffer

        for (const line of lines) {
          const trimmedLine = line.trim()
          if (!trimmedLine || trimmedLine === 'data: [DONE]') {
            continue
          }

          if (trimmedLine.startsWith('data: ')) {
            const jsonStr = trimmedLine.slice(6)
            try {
              const chunk = JSON.parse(jsonStr) as chatCompletionChunk
              yield chunk
            } catch (e) {
              console.error('Error parsing JSON from stream:', e)
            }
          }
        }
      }
    } finally {
      reader.releaseLock()
    }
  }

  private findSessionByModel(modelName: string): sessionInfo | undefined {
    for (const [, session] of this.activeSessions) {
      if (session.modelName === modelName) {
        return session
      }
    }
    return undefined
  }

  override async chat(
    opts: chatCompletionRequest
  ): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
    throw new Error("method not implemented yet")
    const sessionInfo = this.findSessionByModel(opts.model)
    if (!sessionInfo) {
      throw new Error(`No active session found for model: ${opts.model}`)
    }
    const baseUrl = `http://localhost:${sessionInfo.port}/v1`
    const url = `${baseUrl}/chat/completions`
    const headers = {
      'Content-Type': 'application/json',
      'Authorization': `Bearer test-k`,
    }

    const body = JSON.stringify(opts)
    if (opts.stream) {
      return this.handleStreamingResponse(url, headers, body)
    }
    // Handle non-streaming response
    const response = await fetch(url, {
      method: 'POST',
      headers,
      body,
    })

    if (!response.ok) {
      const errorData = await response.json().catch(() => null)
      throw new Error(
        `API request failed with status ${response.status}: ${JSON.stringify(errorData)}`
      )
    }

    return (await response.json()) as chatCompletion
  }
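
// --- Usage sketch (illustration only, not part of this commit) ---------------
// A plausible load -> chat -> unload round trip against this extension. The model
// path is hypothetical; providerId and the default port mirror the code above, and
// chat() resolves the session via findSessionByModel(), i.e. by modelName.
import type { chatCompletion, sessionInfo } from '@janhq/core'

async function roundTrip(engine: llamacpp_extension): Promise<void> {
  const session: sessionInfo = await engine.load({
    modelPath: '/models/qwen3-4b-q4_0.gguf', // hypothetical path
    port: 8080,
  })
  const reply = (await engine.chat({
    provider: 'llamacpp',
    model: session.modelName,
    messages: [{ role: 'user', content: 'ping' }],
    stream: false,
  })) as chatCompletion
  console.log(reply.choices[0]?.message.content)
  await engine.unload({ providerId: 'llamacpp', sessionId: session.sessionId })
}
// ------------------------------------------------------------------------------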

  async delete(opts: deleteOptions): Promise<deleteResult> {
    throw new Error("method not implemented yet")
  override async delete(opts: deleteOptions): Promise<deleteResult> {
    throw new Error('method not implemented yet')
  }

  async import(opts: importOptions): Promise<importResult> {
    throw new Error("method not implemented yet")
  override async import(opts: importOptions): Promise<importResult> {
    throw new Error('method not implemented yet')
  }

  async abortPull(opts: abortPullOptions): Promise<abortPullResult> {
  override async abortPull(opts: abortPullOptions): Promise<abortPullResult> {
    throw new Error('method not implemented yet')
  }

  // Optional method for direct client access
  getChatClient(sessionId: string): any {
    throw new Error("method not implemented yet")
  override getChatClient(sessionId: string): any {
    throw new Error('method not implemented yet')
  }

  onUnload(): void {

@@ -1,214 +0,0 @@
// src/providers/local/types.ts

// --- Re-using OpenAI types (minimal definitions for this example) ---
// In a real project, you'd import these from 'openai' or a shared types package.
export interface chatCompletionRequestMessage {
  role: 'system' | 'user' | 'assistant' | 'tool';
  content: string | null;
  name?: string;
  tool_calls?: any[]; // Simplified
  tool_call_id?: string;
}

export interface chatCompletionRequest {
  model: string; // Model ID, though for local it might be implicit via sessionId
  messages: chatCompletionRequestMessage[];
  temperature?: number | null;
  top_p?: number | null;
  n?: number | null;
  stream?: boolean | null;
  stop?: string | string[] | null;
  max_tokens?: number;
  presence_penalty?: number | null;
  frequency_penalty?: number | null;
  logit_bias?: Record<string, number> | null;
  user?: string;
  // ... TODO: other OpenAI params
}

export interface chatCompletionChunkChoiceDelta {
  content?: string | null;
  role?: 'system' | 'user' | 'assistant' | 'tool';
  tool_calls?: any[]; // Simplified
}

export interface chatCompletionChunkChoice {
  index: number;
  delta: chatCompletionChunkChoiceDelta;
  finish_reason?: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call' | null;
}

export interface chatCompletionChunk {
  id: string;
  object: 'chat.completion.chunk';
  created: number;
  model: string;
  choices: chatCompletionChunkChoice[];
  system_fingerprint?: string;
}


export interface chatCompletionChoice {
  index: number;
  message: chatCompletionRequestMessage; // Response message
  finish_reason: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call';
  logprobs?: any; // Simplified
}

export interface chatCompletion {
  id: string;
  object: 'chat.completion';
  created: number;
  model: string; // Model ID used
  choices: chatCompletionChoice[];
  usage?: {
    prompt_tokens: number;
    completion_tokens: number;
    total_tokens: number;
  };
  system_fingerprint?: string;
}
// --- End OpenAI types ---


// Shared model metadata
export interface modelInfo {
  id: string; // e.g. "qwen3-4B" or "org/model/quant"
  name: string; // human-readable, e.g., "Qwen3 4B Q4_0"
  quant_type?: string; // q4_0 (optional as it might be part of ID or name)
  providerId: string; // e.g. "llama.cpp"
  port: number;
  sizeBytes: number;
  tags?: string[];
  path?: string; // Absolute path to the model file, if applicable
  // Additional provider-specific metadata can be added here
  [key: string]: any;
}

// 1. /list
export interface listOptions {
  providerId: string; // To specify which provider if a central manager calls this
}
export type listResult = modelInfo[];

// 2. /pull
export interface pullOptions {
  providerId: string;
  modelId: string; // Identifier for the model to pull (e.g., from a known registry)
  downloadUrl: string; // URL to download the model from
  /** optional callback to receive download progress */
  onProgress?: (progress: { percent: number; downloadedBytes: number; totalBytes?: number; }) => void;
}
export interface pullResult {
  success: boolean;
  path?: string; // local file path to the pulled model
  error?: string;
  modelInfo?: modelInfo; // Info of the pulled model
}

// 3. /load
export interface loadOptions {
  modelPath: string
  port?: number
  n_gpu_layers?: number
  n_ctx?: number
  threads?: number
  threads_batch?: number
  ctx_size?: number
  n_predict?: number
  batch_size?: number
  ubatch_size?: number
  device?: string
  split_mode?: string
  main_gpu?: number
  flash_attn?: boolean
  cont_batching?: boolean
  no_mmap?: boolean
  mlock?: boolean
  no_kv_offload?: boolean
  cache_type_k?: string
  cache_type_v?: string
  defrag_thold?: number
  rope_scaling?: string
  rope_scale?: number
  rope_freq_base?: number
  rope_freq_scale?: number
}

export interface sessionInfo {
  sessionId: string; // opaque handle for unload/chat
  port: number; // llama-server output port (corrected from portid)
  modelPath: string; // path of the loaded model
  settings: Record<string, unknown>; // The actual settings used to load
}

// 4. /unload
export interface unloadOptions {
  providerId: string;
  sessionId: string;
}
export interface unloadResult {
  success: boolean;
  error?: string;
}

// 5. /chat
export interface chatOptions {
  providerId: string;
  sessionId: string;
  /** Full OpenAI ChatCompletionRequest payload */
  payload: chatCompletionRequest;
}
// Output for /chat will be Promise<ChatCompletion> for non-streaming
// or Promise<AsyncIterable<ChatCompletionChunk>> for streaming

// 6. /delete
export interface deleteOptions {
  providerId: string;
  modelId: string; // The ID of the model to delete (implies finding its path)
  modelPath?: string; // Optionally, direct path can be provided
}
export interface deleteResult {
  success: boolean;
  error?: string;
}

// 7. /import
export interface importOptions {
  providerId: string;
  sourcePath: string; // Path to the local model file to import
  desiredModelId?: string; // Optional: if user wants to name it specifically
}
export interface importResult {
  success: boolean;
  modelInfo?: modelInfo;
  error?: string;
}

// 8. /abortPull
export interface abortPullOptions {
  providerId: string;
  modelId: string; // The modelId whose download is to be aborted
}
export interface abortPullResult {
  success: boolean;
  error?: string;
}


// The interface for any local provider
export interface localProvider {
  readonly providerId: string;

  list(opts: listOptions): Promise<listResult>;
  pull(opts: pullOptions): Promise<pullResult>;
  load(opts: loadOptions): Promise<sessionInfo>;
  unload(opts: unloadOptions): Promise<unloadResult>;
  chat(opts: chatOptions): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>>;
  delete(opts: deleteOptions): Promise<deleteResult>;
  import(opts: importOptions): Promise<importResult>;
  abortPull(opts: abortPullOptions): Promise<abortPullResult>;

  // Optional: for direct access to underlying client if needed for specific streaming cases
  getChatClient?(sessionId: string): any; // e.g., an OpenAI client instance configured for the session
}

@@ -79,7 +79,7 @@ pub struct UnloadResult {

// --- Load Command ---
#[tauri::command]
pub async fn load(
pub async fn load_llama_model(
    app_handle: AppHandle, // Get the AppHandle
    state: State<'_, AppState>, // Access the shared state
    args: Vec<String>, // Arguments from the frontend
@@ -143,7 +143,7 @@ pub async fn load(

// --- Unload Command ---
#[tauri::command]
pub async fn unload(session_id: String, state: State<'_, AppState>) -> ServerResult<UnloadResult> {
pub async fn unload_llama_model(session_id: String, state: State<'_, AppState>) -> ServerResult<UnloadResult> {
    let mut process_lock = state.llama_server_process.lock().await;
    // Take the child process out of the Option, leaving None in its place
    if let Some(mut child) = process_lock.take() {

@@ -86,8 +86,8 @@ pub fn run() {
            core::hardware::get_system_info,
            core::hardware::get_system_usage,
            // llama-cpp extension
            core::utils::extensions::inference_llamacpp_extension::server::load,
            core::utils::extensions::inference_llamacpp_extension::server::unload,
            core::utils::extensions::inference_llamacpp_extension::server::load_llama_model,
            core::utils::extensions::inference_llamacpp_extension::server::unload_llama_model,
        ])
        .manage(AppState {
            app_token: Some(generate_app_token()),

@@ -211,7 +211,7 @@ export const stopModel = async (
): Promise<void> => {
  const providerObj = EngineManager.instance().get(normalizeProvider(provider))
  const modelObj = ModelManager.instance().get(model)
  if (providerObj && modelObj) return providerObj?.unloadModel(modelObj)
  if (providerObj && modelObj) return providerObj?.unload(modelObj)
}

/**