feat: add getTokensCount method to compute token usage (#6467)

* feat: add getTokensCount method to compute token usage

Implemented a new async `getTokensCount` function in the LLaMA.cpp extension.
The method validates the model session, checks process health, applies the request template, and tokenizes the resulting prompt to return the token count. Includes detailed error handling for crashed models and API failures, enabling callers to assess token usage before sending completions.
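A minimal caller-side sketch of the new API (not part of the commit: `engine` stands for an already-loaded llamacpp extension instance and the model id is a placeholder; the request shape mirrors `chatCompletionRequest`):

// Hypothetical usage: estimate prompt size before sending a completion.
const request = {
  model: 'my-local-model', // placeholder model id
  messages: [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Summarize this document for me.' },
  ],
}
try {
  const tokenCount = await engine.getTokensCount(request)
  console.log(`Prompt would consume ~${tokenCount} tokens`)
} catch (e) {
  // Thrown when no session exists for the model or the server has crashed.
  console.error('Token count unavailable:', e)
}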

* Fix: typos

* chore: update UI token usage

* chore: remove unused code

* feat: add image token handling for multimodal LlamaCPP models

Implemented support for counting image tokens when using vision-enabled models (a condensed sketch follows this list):
- Extended `SessionInfo` with an optional `mmproj_path` to store the multimodal projector file path.
- Propagated `mmproj_path` from the Tauri plugin into the session info.
- Added an import of `chatCompletionRequestMessage` and enhanced the token calculation logic in the LlamaCPP extension:
  - Detects image content in messages.
  - Reads GGUF metadata from `mmproj_path` to compute accurate image token counts.
  - Provides a fallback estimate if metadata reading fails.
  - Returns the sum of text and image tokens.
- Introduced helper methods `calculateImageTokens` and `estimateImageTokensFallback`.
- Minor clean-ups such as comment capitalization and debug logging.
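A condensed sketch of the image-token path described above (assuming the extension's `readGgufMetadata` helper and `chatCompletionRequestMessage` type; the /10 scaling, the 256-token fallback, and the placeholder subtraction mirror the implementation in the diff below):

// Sketch only: count image_url parts, then derive a per-image estimate
// from the mmproj GGUF metadata, falling back to a flat estimate.
function countImages(messages: chatCompletionRequestMessage[]): number {
  let count = 0
  for (const message of messages) {
    if (Array.isArray(message.content)) {
      count += message.content.filter((c) => c.type === 'image_url').length
    }
  }
  return count
}

async function estimateImageTokens(
  messages: chatCompletionRequestMessage[],
  mmprojPath: string
): Promise<number> {
  const imageCount = countImages(messages)
  if (imageCount === 0) return 0
  try {
    // Same GGUF reader the extension uses for model validation.
    const { metadata } = await readGgufMetadata(mmprojPath)
    const perImage =
      Math.floor(Number(metadata['clip.vision.projection_dim']) / 10) || 256
    return perImage * imageCount - imageCount // drop one <__image__> placeholder per image
  } catch {
    return 256 * imageCount - imageCount // fallback: SigLIP-style 256 tokens per image
  }
}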

* chore: update FE send params so message content includes the image_url content type

* fix: read mmproj path from session info and correct the token count calculation

* fix: Correct image token estimation calculation in llamacpp extension

This commit addresses an inaccurate token count for images in the llama.cpp extension.

The previous logic incorrectly calculated the token count based on image patch size and dimensions. It has been replaced with a more precise method that uses the `clip.vision.projection_dim` value from the mmproj metadata.

Additionally, unnecessary debug logging was removed, and a new log was added to show the mmproj metadata for improved visibility.
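As implemented below, the metadata value is divided by 10 and floored to get the per-image estimate. For example (a hypothetical value, not from the commit): an mmproj reporting `clip.vision.projection_dim = 2560` yields floor(2560 / 10) = 256 tokens per image, so a prompt with three images adds 256 × 3 − 3 = 765 tokens after subtracting one `<__image__>` placeholder token per image; if the metadata cannot be read, a flat 256-token-per-image fallback applies.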

* fix: per-image token calculation

* fix: crash due to force unwrap

---------

Co-authored-by: Faisal Amir <urmauur@gmail.com>
Co-authored-by: Louis <louis@jan.ai>
Akarshan Biswas 2025-09-23 07:52:19 +05:30 committed by GitHub
parent 05e58cffe8
commit 885da29f28
17 changed files with 904 additions and 39 deletions

View File

@ -13,7 +13,7 @@ export interface chatCompletionRequestMessage {
}
export interface Content {
type: 'text' | 'input_image' | 'input_audio'
type: 'text' | 'image_url' | 'input_audio'
text?: string
image_url?: string
input_audio?: InputAudio
@ -182,6 +182,7 @@ export interface SessionInfo {
model_id: string //name of the model
model_path: string // path of the loaded model
api_key: string
mmproj_path?: string
}
export interface UnloadResult {

View File

@ -21,6 +21,7 @@ import {
events,
AppEvent,
DownloadEvent,
chatCompletionRequestMessage,
} from '@janhq/core'
import { error, info, warn } from '@tauri-apps/plugin-log'
@ -2296,7 +2297,9 @@ export default class llamacpp_extension extends AIEngine {
: Math.floor(maxContextLength)
const mmprojInfo = mmprojPath
? `, mmprojSize=${(mmprojSize / (1024 * 1024)).toFixed(2)}MB, offloadMmproj=${offloadMmproj}`
? `, mmprojSize=${(mmprojSize / (1024 * 1024)).toFixed(
2
)}MB, offloadMmproj=${offloadMmproj}`
: ''
logger.info(
@ -2489,8 +2492,151 @@ export default class llamacpp_extension extends AIEngine {
logger.error('Failed to validate GGUF file:', error)
return {
isValid: false,
error: `Failed to read model metadata: ${error instanceof Error ? error.message : 'Unknown error'}`,
error: `Failed to read model metadata: ${
error instanceof Error ? error.message : 'Unknown error'
}`,
}
}
}
async getTokensCount(opts: chatCompletionRequest): Promise<number> {
const sessionInfo = await this.findSessionByModel(opts.model)
if (!sessionInfo) {
throw new Error(`No active session found for model: ${opts.model}`)
}
// Check if the process is alive
const result = await invoke<boolean>('plugin:llamacpp|is_process_running', {
pid: sessionInfo.pid,
})
if (result) {
try {
await fetch(`http://localhost:${sessionInfo.port}/health`)
} catch (e) {
this.unload(sessionInfo.model_id)
throw new Error('Model appears to have crashed! Please reload!')
}
} else {
throw new Error('Model has crashed! Please reload!')
}
const baseUrl = `http://localhost:${sessionInfo.port}`
const headers = {
'Content-Type': 'application/json',
'Authorization': `Bearer ${sessionInfo.api_key}`,
}
// Count image tokens first
let imageTokens = 0
const hasImages = opts.messages.some(
(msg) =>
Array.isArray(msg.content) &&
msg.content.some((content) => content.type === 'image_url')
)
if (hasImages) {
logger.info('Conversation has images')
try {
// Read mmproj metadata to get vision parameters
logger.info(`MMPROJ PATH: ${sessionInfo.mmproj_path}`)
const metadata = await readGgufMetadata(sessionInfo.mmproj_path)
logger.info(`mmproj metadata: ${JSON.stringify(metadata.metadata)}`)
imageTokens = await this.calculateImageTokens(
opts.messages,
metadata.metadata
)
} catch (error) {
logger.warn('Failed to calculate image tokens:', error)
// Fallback to a rough estimate if metadata reading fails
imageTokens = this.estimateImageTokensFallback(opts.messages)
}
}
// Calculate text tokens
const messages = JSON.stringify({ messages: opts.messages })
let parseResponse = await fetch(`${baseUrl}/apply-template`, {
method: 'POST',
headers: headers,
body: messages,
})
if (!parseResponse.ok) {
const errorData = await parseResponse.json().catch(() => null)
throw new Error(
`API request failed with status ${
parseResponse.status
}: ${JSON.stringify(errorData)}`
)
}
const parsedPrompt = await parseResponse.json()
const response = await fetch(`${baseUrl}/tokenize`, {
method: 'POST',
headers: headers,
body: JSON.stringify({
content: parsedPrompt.prompt,
}),
})
if (!response.ok) {
const errorData = await response.json().catch(() => null)
throw new Error(
`API request failed with status ${response.status}: ${JSON.stringify(
errorData
)}`
)
}
const dataTokens = await response.json()
const textTokens = dataTokens.tokens?.length || 0
return textTokens + imageTokens
}
private async calculateImageTokens(
messages: chatCompletionRequestMessage[],
metadata: Record<string, string>
): Promise<number> {
// Extract vision parameters from metadata
const projectionDim = Math.floor(Number(metadata['clip.vision.projection_dim']) / 10) || 256
// Count images in messages
let imageCount = 0
for (const message of messages) {
if (Array.isArray(message.content)) {
imageCount += message.content.filter(
(content) => content.type === 'image_url'
).length
}
}
logger.info(
`Calculated ${projectionDim} tokens per image, ${imageCount} images total`
)
return projectionDim * imageCount - imageCount // remove the lingering <__image__> placeholder token
}
private estimateImageTokensFallback(
messages: chatCompletionRequestMessage[]
): number {
// Fallback estimation if metadata reading fails
const estimatedTokensPerImage = 256 // Gemma's siglip
let imageCount = 0
for (const message of messages) {
if (Array.isArray(message.content)) {
imageCount += message.content.filter(
(content) => content.type === 'image_url'
).length
}
}
logger.warn(
`Fallback estimation: ${estimatedTokensPerImage} tokens per image, ${imageCount} images total`
)
return imageCount * estimatedTokensPerImage - imageCount // remove the lingering <__image__> placeholder token
}
}

View File

@ -12,7 +12,7 @@ use tokio::time::Instant;
use crate::device::{get_devices_from_backend, DeviceInfo};
use crate::error::{ErrorCode, LlamacppError, ServerError, ServerResult};
use crate::path::{validate_binary_path, validate_model_path, validate_mmproj_path};
use crate::path::{validate_binary_path, validate_mmproj_path, validate_model_path};
use crate::process::{
find_session_by_model_id, get_all_active_sessions, get_all_loaded_model_ids,
get_random_available_port, is_process_running_by_pid,
@ -55,7 +55,20 @@ pub async fn load_llama_model<R: Runtime>(
let port = parse_port_from_args(&args);
let model_path_pb = validate_model_path(&mut args)?;
let _mmproj_path_pb = validate_mmproj_path(&mut args)?;
let mmproj_path_pb = validate_mmproj_path(&mut args)?;
let mmproj_path_string = if let Some(ref _mmproj_pb) = mmproj_path_pb {
// Find the actual mmproj path from args after validation/conversion
if let Some(mmproj_index) = args.iter().position(|arg| arg == "--mmproj") {
Some(args[mmproj_index + 1].clone())
} else {
None
}
} else {
None
};
log::info!("MMPROJ Path string: {}", &mmproj_path_string.as_ref().unwrap_or(&"None".to_string()));
let api_key: String;
@ -211,6 +224,7 @@ pub async fn load_llama_model<R: Runtime>(
model_id: model_id,
model_path: model_path_pb.display().to_string(),
api_key: api_key,
mmproj_path: mmproj_path_string,
};
// Insert session info to process_map
@ -265,7 +279,7 @@ pub async fn unload_llama_model<R: Runtime>(
pub async fn get_devices(
backend_path: &str,
library_path: Option<&str>,
envs: HashMap<String, String>
envs: HashMap<String, String>,
) -> ServerResult<Vec<DeviceInfo>> {
get_devices_from_backend(backend_path, library_path, envs).await
}

View File

@ -11,6 +11,8 @@ pub struct SessionInfo {
pub model_id: String,
pub model_path: String, // path of the loaded model
pub api_key: String,
#[serde(default)]
pub mmproj_path: Option<String>,
}
pub struct LLamaBackendSession {

View File

@ -0,0 +1,283 @@
import { useMemo, useEffect, useState, useRef } from 'react'
import { cn } from '@/lib/utils'
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from '@/components/ui/tooltip'
import { useTokensCount } from '@/hooks/useTokensCount'
import { ThreadMessage } from '@janhq/core'
interface TokenCounterProps {
messages?: ThreadMessage[]
className?: string
compact?: boolean
additionalTokens?: number // For vision tokens or other additions
uploadedFiles?: Array<{
name: string
type: string
size: number
base64: string
dataUrl: string
}>
}
export const TokenCounter = ({
messages = [],
className,
compact = false,
additionalTokens = 0,
uploadedFiles = [],
}: TokenCounterProps) => {
const { calculateTokens, ...tokenData } = useTokensCount(
messages,
uploadedFiles
)
const [isAnimating, setIsAnimating] = useState(false)
const [prevTokenCount, setPrevTokenCount] = useState(0)
const [isUpdating, setIsUpdating] = useState(false)
const timersRef = useRef<{ update?: NodeJS.Timeout; anim?: NodeJS.Timeout }>(
{}
)
// Manual calculation - trigger on click
const handleCalculateTokens = () => {
calculateTokens()
}
// Handle token count changes with proper debouncing and cleanup
useEffect(() => {
const currentTotal = tokenData.tokenCount + additionalTokens
const timers = timersRef.current
// Clear any existing timers
if (timers.update) clearTimeout(timers.update)
if (timers.anim) clearTimeout(timers.anim)
if (currentTotal !== prevTokenCount) {
setIsUpdating(true)
// Clear updating state after a longer delay for smoother transitions
timers.update = setTimeout(() => {
setIsUpdating(false)
}, 250)
// Only animate for significant changes and avoid animating on initial load
if (prevTokenCount > 0) {
const difference = Math.abs(currentTotal - prevTokenCount)
if (difference > 10) {
// Increased threshold to reduce micro-animations
setIsAnimating(true)
timers.anim = setTimeout(() => {
setIsAnimating(false)
}, 600)
}
}
setPrevTokenCount(currentTotal)
}
// Cleanup function
return () => {
if (timers.update) clearTimeout(timers.update)
if (timers.anim) clearTimeout(timers.anim)
}
}, [tokenData.tokenCount, additionalTokens, prevTokenCount])
const totalTokens = useMemo(() => {
return tokenData.tokenCount + additionalTokens
}, [tokenData.tokenCount, additionalTokens])
// Percentage calculation to match useTokensCount exactly
const adjustedPercentage = useMemo(() => {
if (!tokenData.maxTokens) return undefined
return (totalTokens / tokenData.maxTokens) * 100
}, [totalTokens, tokenData.maxTokens])
// Check if percentage exceeds max (100%)
const isOverLimit = useMemo(() => {
return adjustedPercentage !== undefined && adjustedPercentage > 100
}, [adjustedPercentage])
const formatNumber = (num: number) => {
if (num >= 1000000) return `${(num / 1000000).toFixed(1)}M`
if (num >= 1000) return `${(num / 1000).toFixed(1)}K`
return num.toString()
}
if (compact) {
return (
<TooltipProvider delayDuration={isUpdating ? 1200 : 400}>
<Tooltip>
<TooltipTrigger asChild>
<div
className={cn('relative cursor-pointer', className)}
onClick={handleCalculateTokens}
>
{/* Main compact display */}
<div className="flex items-center gap-2 px-2 py-1 rounded-md bg-main-view border border-main-view-fg/10">
<span
className={cn(
'text-xs font-medium tabular-nums transition-all duration-500 ease-out',
isOverLimit ? 'text-destructive' : 'text-accent',
isAnimating && 'scale-110'
)}
>
{adjustedPercentage?.toFixed(1) || '0.0'}%
</span>
<div className="relative w-4 h-4 flex-shrink-0">
<svg
className="w-4 h-4 transform -rotate-90"
viewBox="0 0 16 16"
>
<circle
cx="8"
cy="8"
r="6"
stroke="currentColor"
strokeWidth="1.5"
fill="none"
className="text-main-view-fg/20"
/>
<circle
cx="8"
cy="8"
r="6"
stroke="currentColor"
strokeWidth="1.5"
fill="none"
strokeDasharray={`${2 * Math.PI * 6}`}
strokeDashoffset={`${2 * Math.PI * 6 * (1 - (adjustedPercentage || 0) / 100)}`}
className={cn(
'transition-all duration-500 ease-out',
isOverLimit ? 'stroke-destructive' : 'stroke-accent'
)}
style={{
transformOrigin: 'center',
}}
/>
</svg>
</div>
</div>
</div>
</TooltipTrigger>
<TooltipContent
side="bottom"
align="center"
sideOffset={5}
showArrow={false}
className="min-w-[240px] max-w-[240px] bg-main-view border border-main-view-fg/10 "
>
{/* Detailed breakdown panel */}
<>
{/* Header with percentage and progress bar */}
<div className="mb-3">
<div className="flex items-center justify-between mb-2">
<span
className={cn(
'text-lg font-semibold tabular-nums',
isOverLimit ? 'text-destructive' : 'text-accent'
)}
>
{adjustedPercentage?.toFixed(1) || '0.0'}%
</span>
<span className="text-sm text-main-view-fg/60 font-mono">
{formatNumber(totalTokens)} /{' '}
{formatNumber(tokenData.maxTokens || 0)}
</span>
</div>
{/* Progress bar */}
<div className="w-full h-2 bg-main-view-fg/10 rounded-full overflow-hidden">
<div
className={cn(
'h-2 rounded-full transition-all duration-500 ease-out',
isOverLimit ? 'bg-destructive' : 'bg-accent'
)}
style={{
width: `${Math.min(adjustedPercentage || 0, 100)}%`,
}}
/>
</div>
</div>
{/* Token breakdown */}
<div className="space-y-2 mb-3">
<div className="flex items-center justify-between text-sm">
<span className="text-main-view-fg/60">Text</span>
<span className="text-main-view-fg font-mono">
{formatNumber(Math.max(0, tokenData.tokenCount))}
</span>
</div>
</div>
{/* Remaining tokens */}
<div className="border-t border-main-view-fg/10 pt-2">
<div className="flex items-center justify-between text-sm">
<span className="text-main-view-fg/60">Remaining</span>
<span className="text-main-view-fg font-semibold font-mono">
{formatNumber(
Math.max(0, (tokenData.maxTokens || 0) - totalTokens)
)}
</span>
</div>
</div>
</>
</TooltipContent>
</Tooltip>
</TooltipProvider>
)
}
// Non-compact: Simple inline display
return (
<div
className={cn(
'flex items-center w-full justify-between gap-2 py-1 text-xs text-main-view-fg/50',
className
)}
>
<div className="space-x-0.5">
<span>Context&nbsp;</span>
<span
className={cn(
'font-mono font-bold transition-all duration-500 ease-out',
isAnimating && 'scale-110'
)}
>
{formatNumber(totalTokens)}
</span>
{tokenData.maxTokens && (
<>
<span>/</span>
<span
className={cn(
'font-mono font-bold transition-all duration-500 ease-out',
isAnimating && 'scale-110'
)}
>
{formatNumber(tokenData.maxTokens)}
</span>
<span
className={cn(
'ml-1 font-mono font-bold transition-all duration-500 ease-out',
isOverLimit ? 'text-destructive' : 'text-accent',
isAnimating && 'scale-110'
)}
>
({adjustedPercentage?.toFixed(1) || '0.0'}%)
</span>
{isOverLimit && (
<span className="text-xs text-main-view-fg/40">
&nbsp;{isOverLimit ? '⚠️ Over limit' : 'Tokens used'}
</span>
)}
</>
)}
</div>
</div>
)
}

View File

@ -35,9 +35,12 @@ function TooltipTrigger({
function TooltipContent({
className,
sideOffset = 0,
showArrow = true,
children,
...props
}: React.ComponentProps<typeof TooltipPrimitive.Content>) {
}: React.ComponentProps<typeof TooltipPrimitive.Content> & {
showArrow?: boolean
}) {
return (
<TooltipPrimitive.Portal>
<TooltipPrimitive.Content
@ -50,7 +53,9 @@ function TooltipContent({
{...props}
>
{children}
<TooltipPrimitive.Arrow className="bg-main-view-fg fill-main-view-fg z-50 size-2.5 translate-y-[calc(-50%_-_2px)] rotate-45 rounded-[2px]" />
{showArrow && (
<TooltipPrimitive.Arrow className="bg-main-view-fg fill-main-view-fg z-50 size-2.5 translate-y-[calc(-50%_-_2px)] rotate-45 rounded-[2px]" />
)}
</TooltipPrimitive.Content>
</TooltipPrimitive.Portal>
)

View File

@ -34,6 +34,9 @@ import { ModelLoader } from '@/containers/loaders/ModelLoader'
import DropdownToolsAvailable from '@/containers/DropdownToolsAvailable'
import { useServiceHub } from '@/hooks/useServiceHub'
import { useTools } from '@/hooks/useTools'
import { TokenCounter } from '@/components/TokenCounter'
import { useMessages } from '@/hooks/useMessages'
import { useShallow } from 'zustand/react/shallow'
type ChatInputProps = {
className?: string
@ -56,9 +59,21 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
const setPrompt = usePrompt((state) => state.setPrompt)
const currentThreadId = useThreads((state) => state.currentThreadId)
const { t } = useTranslation()
const { spellCheckChatInput } = useGeneralSetting()
const spellCheckChatInput = useGeneralSetting(
(state) => state.spellCheckChatInput
)
const tokenCounterCompact = useGeneralSetting(
(state) => state.tokenCounterCompact
)
useTools()
// Get current thread messages for token counting
const threadMessages = useMessages(
useShallow((state) =>
currentThreadId ? state.messages[currentThreadId] : []
)
)
const maxRows = 10
const selectedModel = useModelProvider((state) => state.selectedModel)
@ -79,6 +94,7 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
const [connectedServers, setConnectedServers] = useState<string[]>([])
const [isDragOver, setIsDragOver] = useState(false)
const [hasMmproj, setHasMmproj] = useState(false)
const [hasActiveModels, setHasActiveModels] = useState(false)
// Check for connected MCP servers
useEffect(() => {
@ -100,6 +116,28 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
return () => clearInterval(intervalId)
}, [serviceHub])
// Check for active models
useEffect(() => {
const checkActiveModels = async () => {
try {
const activeModels = await serviceHub
.models()
.getActiveModels('llamacpp')
setHasActiveModels(activeModels.length > 0)
} catch (error) {
console.error('Failed to get active models:', error)
setHasActiveModels(false)
}
}
checkActiveModels()
// Poll for active models every 3 seconds
const intervalId = setInterval(checkActiveModels, 3000)
return () => clearInterval(intervalId)
}, [serviceHub])
// Check for mmproj existence or vision capability when model changes
useEffect(() => {
const checkMmprojSupport = async () => {
@ -742,35 +780,51 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
</div>
</div>
{streamingContent ? (
<Button
variant="destructive"
size="icon"
onClick={() =>
stopStreaming(currentThreadId ?? streamingContent.thread_id)
}
>
<IconPlayerStopFilled />
</Button>
) : (
<Button
variant={
!prompt.trim() && uploadedFiles.length === 0
? null
: 'default'
}
size="icon"
disabled={!prompt.trim() && uploadedFiles.length === 0}
data-test-id="send-message-button"
onClick={() => handleSendMesage(prompt)}
>
{streamingContent ? (
<span className="animate-spin h-4 w-4 border-2 border-current border-t-transparent rounded-full" />
) : (
<ArrowRight className="text-primary-fg" />
<div className="flex items-center gap-2">
{selectedProvider === 'llamacpp' &&
hasActiveModels &&
tokenCounterCompact &&
!initialMessage &&
(threadMessages?.length > 0 || prompt.trim().length > 0) && (
<div className="flex-1 flex justify-center">
<TokenCounter
messages={threadMessages || []}
compact={true}
uploadedFiles={uploadedFiles}
/>
</div>
)}
</Button>
)}
{streamingContent ? (
<Button
variant="destructive"
size="icon"
onClick={() =>
stopStreaming(currentThreadId ?? streamingContent.thread_id)
}
>
<IconPlayerStopFilled />
</Button>
) : (
<Button
variant={
!prompt.trim() && uploadedFiles.length === 0
? null
: 'default'
}
size="icon"
disabled={!prompt.trim() && uploadedFiles.length === 0}
data-test-id="send-message-button"
onClick={() => handleSendMesage(prompt)}
>
{streamingContent ? (
<span className="animate-spin h-4 w-4 border-2 border-current border-t-transparent rounded-full" />
) : (
<ArrowRight className="text-primary-fg" />
)}
</Button>
)}
</div>
</div>
</div>
</div>
@ -792,6 +846,20 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
</div>
</div>
)}
{selectedProvider === 'llamacpp' &&
hasActiveModels &&
!tokenCounterCompact &&
!initialMessage &&
(threadMessages?.length > 0 || prompt.trim().length > 0) && (
<div className="flex-1 w-full flex justify-start px-2">
<TokenCounter
messages={threadMessages || []}
compact={false}
uploadedFiles={uploadedFiles}
/>
</div>
)}
</div>
)
}

View File

@ -0,0 +1,17 @@
import { useGeneralSetting } from '@/hooks/useGeneralSetting'
import { Switch } from '@/components/ui/switch'
export function TokenCounterCompactSwitcher() {
const { tokenCounterCompact, setTokenCounterCompact } = useGeneralSetting()
const toggleTokenCounterCompact = () => {
setTokenCounterCompact(!tokenCounterCompact)
}
return (
<Switch
checked={tokenCounterCompact}
onCheckedChange={toggleTokenCounterCompact}
/>
)
}

View File

@ -6,9 +6,11 @@ import { ExtensionManager } from '@/lib/extension'
type LeftPanelStoreState = {
currentLanguage: Language
spellCheckChatInput: boolean
tokenCounterCompact: boolean
huggingfaceToken?: string
setHuggingfaceToken: (token: string) => void
setSpellCheckChatInput: (value: boolean) => void
setTokenCounterCompact: (value: boolean) => void
setCurrentLanguage: (value: Language) => void
}
@ -17,8 +19,10 @@ export const useGeneralSetting = create<LeftPanelStoreState>()(
(set) => ({
currentLanguage: 'en',
spellCheckChatInput: true,
tokenCounterCompact: true,
huggingfaceToken: undefined,
setSpellCheckChatInput: (value) => set({ spellCheckChatInput: value }),
setTokenCounterCompact: (value) => set({ tokenCounterCompact: value }),
setCurrentLanguage: (value) => set({ currentLanguage: value }),
setHuggingfaceToken: (token) => {
set({ huggingfaceToken: token })

View File

@ -0,0 +1,200 @@
import { useCallback, useState, useRef, useEffect, useMemo } from 'react'
import { ThreadMessage, ContentType } from '@janhq/core'
import { useServiceHub } from './useServiceHub'
import { useModelProvider } from './useModelProvider'
import { usePrompt } from './usePrompt'
export interface TokenCountData {
tokenCount: number
maxTokens?: number
percentage?: number
isNearLimit: boolean
loading: boolean
error?: string
}
export const useTokensCount = (
messages: ThreadMessage[] = [],
uploadedFiles?: Array<{
name: string
type: string
size: number
base64: string
dataUrl: string
}>
) => {
const [tokenData, setTokenData] = useState<TokenCountData>({
tokenCount: 0,
loading: false,
isNearLimit: false,
})
const debounceTimeoutRef = useRef<NodeJS.Timeout | undefined>(undefined)
const isIncreasingContextSize = useRef<boolean>(false)
const serviceHub = useServiceHub()
const { selectedModel, selectedProvider } = useModelProvider()
const { prompt } = usePrompt()
// Create messages with current prompt for live calculation
const messagesWithPrompt = useMemo(() => {
const result = [...messages]
if (prompt.trim() || (uploadedFiles && uploadedFiles.length > 0)) {
const content = []
// Add text content if prompt exists
if (prompt.trim()) {
content.push({ type: ContentType.Text, text: { value: prompt } })
}
// Add image content for uploaded files
if (uploadedFiles && uploadedFiles.length > 0) {
uploadedFiles.forEach((file) => {
content.push({
type: ContentType.Image,
image_url: {
url: file.dataUrl,
detail: 'high', // Default to high detail for token calculation
},
})
})
}
if (content.length > 0) {
result.push({
id: 'temp-prompt',
thread_id: '',
role: 'user',
content,
created_at: Date.now(),
} as ThreadMessage)
}
}
return result
}, [messages, prompt, uploadedFiles])
// Debounced calculation that includes current prompt
const debouncedCalculateTokens = useCallback(async () => {
const modelId = selectedModel?.id
if (!modelId || selectedProvider !== 'llamacpp') {
setTokenData({
tokenCount: 0,
loading: false,
isNearLimit: false,
})
return
}
// Use messages with current prompt for calculation
const messagesToCalculate = messagesWithPrompt
if (messagesToCalculate.length === 0) {
setTokenData({
tokenCount: 0,
loading: false,
isNearLimit: false,
})
return
}
setTokenData((prev) => ({ ...prev, loading: true, error: undefined }))
try {
const tokenCount = await serviceHub
.models()
.getTokensCount(modelId, messagesToCalculate)
const maxTokensValue =
selectedModel?.settings?.ctx_len?.controller_props?.value
const maxTokensNum =
typeof maxTokensValue === 'string'
? parseInt(maxTokensValue)
: typeof maxTokensValue === 'number'
? maxTokensValue
: undefined
const percentage = maxTokensNum
? (tokenCount / maxTokensNum) * 100
: undefined
const isNearLimit = percentage ? percentage > 85 : false
setTokenData({
tokenCount,
maxTokens: maxTokensNum,
percentage,
isNearLimit,
loading: false,
})
} catch (error) {
console.error('Failed to calculate tokens:', error)
setTokenData((prev) => ({
...prev,
loading: false,
error:
error instanceof Error ? error.message : 'Failed to calculate tokens',
}))
}
}, [
selectedModel?.id,
selectedProvider,
messagesWithPrompt,
serviceHub,
selectedModel?.settings?.ctx_len?.controller_props?.value,
])
// Debounced effect that triggers when prompt or messages change
useEffect(() => {
// Clear existing timeout
if (debounceTimeoutRef.current) {
clearTimeout(debounceTimeoutRef.current)
}
// Skip calculation if we're currently increasing context size
if (isIncreasingContextSize.current) {
return
}
// Only calculate if we have messages or a prompt
if (
messagesWithPrompt.length > 0 &&
selectedProvider === 'llamacpp' &&
selectedModel?.id
) {
debounceTimeoutRef.current = setTimeout(() => {
debouncedCalculateTokens()
}, 150) // 150ms debounce for more responsive updates
} else {
// Reset immediately if no content
setTokenData({
tokenCount: 0,
loading: false,
isNearLimit: false,
})
}
return () => {
if (debounceTimeoutRef.current) {
clearTimeout(debounceTimeoutRef.current)
}
}
}, [
prompt,
messages.length,
selectedModel?.id,
selectedProvider,
messagesWithPrompt.length,
debouncedCalculateTokens,
])
// Manual calculation function (for click events)
const calculateTokens = useCallback(async () => {
// Trigger the debounced calculation immediately
if (debounceTimeoutRef.current) {
clearTimeout(debounceTimeoutRef.current)
}
await debouncedCalculateTokens()
}, [debouncedCalculateTokens])
return {
...tokenData,
calculateTokens,
}
}

View File

@ -100,6 +100,8 @@
"resetAppearanceSuccessDesc": "Alle Darstellungseinstellungen wurden auf die Standardeinstellungen zurückgesetzt.",
"chatWidth": "Chat Breite",
"chatWidthDesc": "Passe die Breite der Chatansicht an.",
"tokenCounterCompact": "Kompakter Token-Zähler",
"tokenCounterCompactDesc": "Token-Zähler im Chat-Eingabefeld anzeigen. Wenn deaktiviert, wird der Token-Zähler unter dem Eingabefeld angezeigt.",
"codeBlockTitle": "Code Block",
"codeBlockDesc": "Wähle einen Stil zur Syntaxhervorhebung.",
"showLineNumbers": "Zeilennummern anzeigen",

View File

@ -100,6 +100,8 @@
"resetAppearanceSuccessDesc": "All appearance settings have been restored to default.",
"chatWidth": "Chat Width",
"chatWidthDesc": "Customize the width of the chat view.",
"tokenCounterCompact": "Compact Token Counter",
"tokenCounterCompactDesc": "Show token counter inside chat input. When disabled, token counter appears below the input.",
"codeBlockTitle": "Code Block",
"codeBlockDesc": "Choose a syntax highlighting style.",
"showLineNumbers": "Show Line Numbers",

View File

@ -100,6 +100,8 @@
"resetAppearanceSuccessDesc": "Tất cả cài đặt giao diện đã được khôi phục về mặc định.",
"chatWidth": "Chiều rộng trò chuyện",
"chatWidthDesc": "Tùy chỉnh chiều rộng của chế độ xem trò chuyện.",
"tokenCounterCompact": "Bộ đếm token nhỏ gọn",
"tokenCounterCompactDesc": "Hiển thị bộ đếm token bên trong ô nhập trò chuyện. Khi tắt, bộ đếm token sẽ xuất hiện bên dưới ô nhập.",
"codeBlockTitle": "Khối mã",
"codeBlockDesc": "Chọn kiểu tô sáng cú pháp.",
"showLineNumbers": "Hiển thị số dòng",

View File

@ -100,6 +100,8 @@
"resetAppearanceSuccessDesc": "所有外观设置已恢复为默认值。",
"chatWidth": "聊天宽度",
"chatWidthDesc": "自定义聊天视图的宽度。",
"tokenCounterCompact": "紧凑令牌计数器",
"tokenCounterCompactDesc": "在聊天输入框内显示令牌计数器。禁用时,令牌计数器显示在输入框下方。",
"codeBlockTitle": "代码块",
"codeBlockDesc": "选择语法高亮样式。",
"showLineNumbers": "显示行号",
@ -264,4 +266,3 @@
"updateError": "更新 Llamacpp 失败"
}
}

View File

@ -19,6 +19,7 @@ import { LineNumbersSwitcher } from '@/containers/LineNumbersSwitcher'
import { CodeBlockExample } from '@/containers/CodeBlockExample'
import { toast } from 'sonner'
import { ChatWidthSwitcher } from '@/containers/ChatWidthSwitcher'
import { TokenCounterCompactSwitcher } from '@/containers/TokenCounterCompactSwitcher'
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export const Route = createFileRoute(route.settings.appearance as any)({
@ -115,6 +116,11 @@ function Appareances() {
description={t('settings:appearance.chatWidthDesc')}
/>
<ChatWidthSwitcher />
<CardItem
title={t('settings:appearance.tokenCounterCompact')}
description={t('settings:appearance.tokenCounterCompactDesc')}
actions={<TokenCounterCompactSwitcher />}
/>
</Card>
{/* Codeblock */}

View File

@ -9,6 +9,8 @@ import {
SessionInfo,
SettingComponentProps,
modelInfo,
ThreadMessage,
ContentType,
} from '@janhq/core'
import { Model as CoreModel } from '@janhq/core'
import type {
@ -544,4 +546,113 @@ export class DefaultModelsService implements ModelsService {
}
}
}
async getTokensCount(
modelId: string,
messages: ThreadMessage[]
): Promise<number> {
try {
const engine = this.getEngine('llamacpp') as AIEngine & {
getTokensCount?: (opts: {
model: string
messages: Array<{
role: string
content:
| string
| Array<{
type: string
text?: string
image_url?: {
detail?: string
url?: string
}
}>
}>
}) => Promise<number>
}
if (engine && typeof engine.getTokensCount === 'function') {
// Transform Jan's ThreadMessage format to OpenAI chat completion format
const transformedMessages = messages
.map((message) => {
// Handle different content types
let content:
| string
| Array<{
type: string
text?: string
image_url?: {
detail?: string
url?: string
}
}> = ''
if (message.content && message.content.length > 0) {
// Check if there are any image_url content types
const hasImages = message.content.some(
(content) => content.type === ContentType.Image
)
if (hasImages) {
// For multimodal messages, preserve the array structure
content = message.content.map((contentItem) => {
if (contentItem.type === ContentType.Text) {
return {
type: 'text',
text: contentItem.text?.value || '',
}
} else if (contentItem.type === ContentType.Image) {
return {
type: 'image_url',
image_url: {
detail: contentItem.image_url?.detail,
url: contentItem.image_url?.url || '',
},
}
}
// Fallback for unknown content types
return {
type: contentItem.type,
text: contentItem.text?.value,
image_url: contentItem.image_url,
}
})
} else {
// For text-only messages, keep the string format
const textContents = message.content
.filter(
(content) =>
content.type === ContentType.Text && content.text?.value
)
.map((content) => content.text?.value || '')
content = textContents.join(' ')
}
}
return {
role: message.role,
content,
}
})
.filter((msg) =>
typeof msg.content === 'string'
? msg.content.trim() !== ''
: Array.isArray(msg.content) && msg.content.length > 0
) // Filter out empty messages
return await engine.getTokensCount({
model: modelId,
messages: transformedMessages,
})
}
// Fallback if method is not available
console.warn('getTokensCount method not available in llamacpp engine')
return 0
} catch (error) {
console.error(`Error getting tokens count for model ${modelId}:`, error)
return 0
}
}
}

View File

@ -2,7 +2,7 @@
* Models Service Types
*/
import { SessionInfo, modelInfo } from '@janhq/core'
import { SessionInfo, modelInfo, ThreadMessage } from '@janhq/core'
import { Model as CoreModel } from '@janhq/core'
// Types for model catalog
@ -142,4 +142,5 @@ export interface ModelsService {
mmprojPath?: string,
requestedCtx?: number
): Promise<ModelPlan>
getTokensCount(modelId: string, messages: ThreadMessage[]): Promise<number>
}