diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..600e365ec --- /dev/null +++ b/.dockerignore @@ -0,0 +1 @@ +**/node_modules \ No newline at end of file diff --git a/.github/workflows/jan-electron-build.yml b/.github/workflows/jan-electron-build.yml index c6a2df1b9..339365625 100644 --- a/.github/workflows/jan-electron-build.yml +++ b/.github/workflows/jan-electron-build.yml @@ -52,7 +52,7 @@ jobs: - name: Install yarn dependencies run: | yarn install - yarn build:plugins-darwin + yarn build:pull-plugins env: APP_PATH: "." DEVELOPER_ID: ${{ secrets.DEVELOPER_ID }} @@ -104,7 +104,7 @@ jobs: run: | yarn config set network-timeout 300000 yarn install - yarn build:plugins + yarn build:pull-plugins - name: Build and publish app run: | @@ -153,7 +153,7 @@ jobs: run: | yarn config set network-timeout 300000 yarn install - yarn build:plugins + yarn build:pull-plugins - name: Build and publish app run: | diff --git a/.github/workflows/jan-electron-linter-and-test.yml b/.github/workflows/jan-electron-linter-and-test.yml index e9e3f091a..8c0960463 100644 --- a/.github/workflows/jan-electron-linter-and-test.yml +++ b/.github/workflows/jan-electron-linter-and-test.yml @@ -46,8 +46,8 @@ jobs: yarn config set network-timeout 300000 yarn install yarn lint - yarn build:plugins - yarn build + yarn build:pull-plugins + yarn build:test yarn test env: CSC_IDENTITY_AUTO_DISCOVERY: "false" @@ -76,9 +76,8 @@ jobs: run: | yarn config set network-timeout 300000 yarn install - yarn lint - yarn build:plugins - yarn build:win32 + yarn build:pull-plugins + yarn build:test-win32 yarn test test-on-ubuntu: @@ -105,7 +104,6 @@ jobs: echo -e "Display ID: $DISPLAY" yarn config set network-timeout 300000 yarn install - yarn lint - yarn build:plugins - yarn build:linux + yarn build:pull-plugins + yarn build:test-linux yarn test \ No newline at end of file diff --git a/.gitignore b/.gitignore index a190afdbf..dc634deb1 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 
+16,9 @@ package-lock.json *.log plugin-core/lib core/lib/** + +# Nitro binary files +plugins/inference-plugin/nitro/*/nitro +plugins/inference-plugin/nitro/*/*.exe +plugins/inference-plugin/nitro/*/*.dll +plugins/inference-plugin/nitro/*/*.metal \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..949a92673 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,39 @@ +FROM node:20-bullseye AS base + +# 1. Install dependencies only when needed +FROM base AS deps +WORKDIR /app + +# Install dependencies based on the preferred package manager +COPY package.json yarn.lock* package-lock.json* pnpm-lock.yaml* ./ +RUN yarn install + +# # 2. Rebuild the source code only when needed +FROM base AS builder +WORKDIR /app +COPY --from=deps /app/node_modules ./node_modules +COPY . . +# This will do the trick, use the corresponding env file for each environment. +RUN yarn workspace server install +RUN yarn server:prod + +# 3. Production image, copy all the files and run next +FROM base AS runner +WORKDIR /app + +ENV NODE_ENV=production + +# RUN addgroup -g 1001 -S nodejs; +COPY --from=builder /app/server/build ./ + +# Automatically leverage output traces to reduce image size +# https://nextjs.org/docs/advanced-features/output-file-tracing +COPY --from=builder /app/server/node_modules ./node_modules +COPY --from=builder /app/server/package.json ./package.json + +EXPOSE 4000 3928 + +ENV PORT 4000 +ENV APPDATA /app/data + +CMD ["node", "main.js"] \ No newline at end of file diff --git a/adr/Makefile b/adr/Makefile deleted file mode 100644 index edbe801ea..000000000 --- a/adr/Makefile +++ /dev/null @@ -1,14 +0,0 @@ -newadr: - @echo "Initiating an ADR..." 
- @last_number=$$(ls $(CURDIR)/adr-[0-9][0-9][0-9]-* | sort -V | tail -n 1 | cut -d '-' -f 2); \ - last_number=$$(echo $$last_number | sed 's/^0*//'); \ - next_number=$$(printf "%03d" $$(( $$last_number + 1 ))); \ - read -p "Enter ADR title: " title; \ - title=$$(echo $$title | tr ' ' '-'); \ - cp $(CURDIR)/adr-template.md $(CURDIR)/adr-$$next_number-$$title.md; \ - date=$$(date +%Y-%m-%d); \ - usernames=$$(git config user.name); \ - sed -i '' 's/{ADR-NUM}/'$$next_number'/g' $(CURDIR)/adr-$$next_number-$$title.md; \ - sed -i '' 's/{TITLE}/'$$title'/g' $(CURDIR)/adr-$$next_number-$$title.md; \ - sed -i '' 's/{DATE}/'$$date'/g' $(CURDIR)/adr-$$next_number-$$title.md; \ - sed -i '' 's/{USERNAMES}/'$$usernames'/g' $(CURDIR)/adr-$$next_number-$$title.md \ No newline at end of file diff --git a/adr/adr-template.md b/adr/adr-template.md deleted file mode 100644 index 61e14f928..000000000 --- a/adr/adr-template.md +++ /dev/null @@ -1,31 +0,0 @@ -# ADR {ADR-NUM}: {TITLE} - -## Changelog - -- {DATE}: Initial draft - -## Authors - -- {USERNAMES} - -## Status - -What is the status, such as proposed, accepted, rejected, deprecated, superseded, etc.? - -{Proposed|Accepted|Rejected} - -## Context - -What is the issue that we're seeing that is motivating this decision or change? - -## Decision - -What is the change that we're proposing and/or doing? - -## Consequences - -What becomes easier or more difficult to do because of this change? - -## Alternatives - -## Reference diff --git a/core/core.ts b/core/core.ts index 84b2524e6..1820d5b9a 100644 --- a/core/core.ts +++ b/core/core.ts @@ -28,6 +28,13 @@ const downloadFile: (url: string, fileName: string) => Promise = (url, file const deleteFile: (path: string) => Promise = (path) => window.coreAPI?.deleteFile(path) ?? window.electronAPI?.deleteFile(path); +/** + * Retrieves the path to the app data directory using the `coreAPI` object. + * If the `coreAPI` object is not available, the function returns `undefined`. 
+ * @returns A Promise that resolves with the path to the app data directory, or `undefined` if the `coreAPI` object is not available. + */ +const appDataPath: () => Promise = () => window.coreAPI?.appDataPath(); + /** Register extension point function type definition * */ @@ -46,9 +53,10 @@ export const core = { invokePluginFunc, downloadFile, deleteFile, + appDataPath, }; /** * Functions exports */ -export { invokePluginFunc, downloadFile, deleteFile }; +export { invokePluginFunc, downloadFile, deleteFile, appDataPath }; diff --git a/core/package.json b/core/package.json index c74f87800..d80e89a6c 100644 --- a/core/package.json +++ b/core/package.json @@ -1,6 +1,6 @@ { "name": "@janhq/core", - "version": "0.1.8", + "version": "0.1.9", "description": "Plugin core lib", "keywords": [ "jan", diff --git a/docs/src/pages/index.js b/docs/src/pages/index.js index 1325920b8..f6616599e 100644 --- a/docs/src/pages/index.js +++ b/docs/src/pages/index.js @@ -87,12 +87,19 @@ export default function Home() { Run Your Own AI

- Run Large Language Models locally on  + {/* Run Large Language Models locally on  MacWindows  or  - Linux. + Linux. */} + Jan is a powerful  + Personal AI +  built to run locally on your machine, +
+ with a rich  + app and  + plugin ecosystem.

diff --git a/electron/core/plugin-manager/execution/ExtensionPoint.test.js b/electron/core/plugin-manager/execution/ExtensionPoint.test.js deleted file mode 100644 index 54d7b37ec..000000000 --- a/electron/core/plugin-manager/execution/ExtensionPoint.test.js +++ /dev/null @@ -1,116 +0,0 @@ -import Ep from './ExtensionPoint' - -/** @type {Ep} */ -let ep -const changeListener = jest.fn() - -const objectRsp = { foo: 'bar' } -const funcRsp = arr => { - arr || (arr = []) - arr.push({ foo: 'baz' }) - return arr -} - -beforeEach(() => { - ep = new Ep('test-ep') - ep.register('test-ext-obj', objectRsp) - ep.register('test-ext-func', funcRsp, 10) - ep.onRegister('test', changeListener) -}) - - -it('should create a new extension point by providing a name', () => { - expect(ep.name).toEqual('test-ep') -}) - -it('should register extension with extension point', () => { - expect(ep._extensions).toContainEqual({ - name: 'test-ext-func', - response: funcRsp, - priority: 10 - }) -}) - -it('should register extension with a default priority of 0 if not provided', () => { - expect(ep._extensions).toContainEqual({ - name: 'test-ext-obj', - response: objectRsp, - priority: 0 - }) -}) - -it('should execute the change listeners on registering a new extension', () => { - changeListener.mockClear() - ep.register('test-change-listener', true) - expect(changeListener.mock.calls.length).toBeTruthy() -}) - -it('should unregister an extension with the provided name if it exists', () => { - ep.unregister('test-ext-obj') - - expect(ep._extensions).not.toContainEqual( - expect.objectContaining({ - name: 'test-ext-obj' - }) - ) -}) - -it('should not unregister any extensions if the provided name does not exist', () => { - ep.unregister('test-ext-invalid') - - expect(ep._extensions.length).toBe(2) -}) - -it('should execute the change listeners on unregistering an extension', () => { - changeListener.mockClear() - ep.unregister('test-ext-obj') - expect(changeListener.mock.calls.length).toBeTruthy() 
-}) - -it('should empty the registry of all extensions on clearing', () => { - ep.clear() - - expect(ep._extensions).toEqual([]) -}) - -it('should execute the change listeners on clearing extensions', () => { - changeListener.mockClear() - ep.clear() - expect(changeListener.mock.calls.length).toBeTruthy() -}) - -it('should return the relevant extension using the get method', () => { - const ext = ep.get('test-ext-obj') - - expect(ext).toEqual({ foo: 'bar' }) -}) - -it('should return the false using the get method if the extension does not exist', () => { - const ext = ep.get('test-ext-invalid') - - expect(ext).toBeUndefined() -}) - -it('should provide an array with all responses, including promises where necessary, using the execute method', async () => { - ep.register('test-ext-async', () => new Promise(resolve => setTimeout(resolve, 0, { foo: 'delayed' }))) - const arr = ep.execute([]) - - const res = await Promise.all(arr) - - expect(res).toContainEqual({ foo: 'bar' }) - expect(res).toContainEqual([{ foo: 'baz' }]) - expect(res).toContainEqual({ foo: 'delayed' }) - expect(res.length).toBe(3) -}) - -it('should provide an array including all responses in priority order, using the executeSerial method provided with an array', async () => { - const res = await ep.executeSerial([]) - - expect(res).toEqual([{ "foo": "bar" }, { "foo": "baz" }]) -}) - -it('should provide an array including the last response using the executeSerial method provided with something other than an array', async () => { - const res = await ep.executeSerial() - - expect(res).toEqual([{ "foo": "baz" }]) -}) diff --git a/electron/core/plugin-manager/execution/Plugin.test.js b/electron/core/plugin-manager/execution/Plugin.test.js deleted file mode 100644 index 7982decdc..000000000 --- a/electron/core/plugin-manager/execution/Plugin.test.js +++ /dev/null @@ -1,22 +0,0 @@ -import { setImporter } from "./import-manager" -import Plugin from './Plugin' - -describe('triggerExport', () => { - it('should 
call the provided export on the plugin\'s main file', async () => { - // Set up mock importer with mock main plugin file - const mockExport = jest.fn() - const mockImporter = jest.fn(() => ({ - lifeCycleFn: mockExport - })) - setImporter(mockImporter) - - // Call triggerExport on new plugin - const plgUrl = 'main' - const plugin = new Plugin('test', plgUrl, ['ap1'], true) - await plugin.triggerExport('lifeCycleFn') - - // Check results - expect(mockImporter.mock.lastCall).toEqual([plgUrl]) - expect(mockExport.mock.calls.length).toBeTruthy() - }) -}) \ No newline at end of file diff --git a/electron/core/plugin-manager/execution/activation-manager.test.js b/electron/core/plugin-manager/execution/activation-manager.test.js deleted file mode 100644 index dd7ff9251..000000000 --- a/electron/core/plugin-manager/execution/activation-manager.test.js +++ /dev/null @@ -1,307 +0,0 @@ -import { setup } from './index' -import { register, trigger, remove, clear, get } from "./activation-manager"; -import { add } from './extension-manager' - -let mockPlugins = {} -setup({ - importer(plugin) { return mockPlugins[plugin] } -}) - -afterEach(() => { - clear() - mockPlugins = {} -}) - -describe('register', () => { - it('should add a new activation point to the register when a new, valid plugin is registered', - () => { - register({ - name: 'test', - url: 'testPkg', - activationPoints: ['ap1', 'ap2'], - active: true - }) - - expect(get()).toEqual([ - { - plugin: 'test', - url: 'testPkg', - activationPoint: 'ap1', - activated: false - }, - { - plugin: 'test', - url: 'testPkg', - activationPoint: 'ap2', - activated: false - } - ]) - } - ) - - it('should not add an activation point to the register when an existing, valid plugin is registered', - () => { - register({ - name: 'test', - url: 'testPkg', - activationPoints: ['ap1', 'ap2'], - active: true - }) - - register({ - name: 'test', - url: 'testPkg', - activationPoints: ['ap2', 'ap3'], - active: true - }) - - expect(get()).toEqual([ - 
{ - plugin: 'test', - url: 'testPkg', - activationPoint: 'ap1', - activated: false - }, - { - plugin: 'test', - url: 'testPkg', - activationPoint: 'ap2', - activated: false - }, - { - plugin: 'test', - url: 'testPkg', - activationPoint: 'ap3', - activated: false - }, - ]) - } - ) - - it('should throw an error when an invalid plugin is registered', - () => { - const noActivationPoints = () => register({ - name: 'test', - url: 'testPkg', - active: true - }) - - expect(noActivationPoints).toThrow(/does not have any activation points set up in its manifest/) - } - ) -}) - -describe('trigger', () => { - it('should trigger all and only the activations with for the given execution point on triggering an execution, using the defined importer', - async () => { - const triggered = [] - - mockPlugins.plugin1 = { - ap1() { triggered.push('plugin1-ap1') } - } - mockPlugins.plugin2 = { - ap2() { triggered.push('plugin2-ap2') } - } - mockPlugins.plugin3 = { - ap1() { triggered.push('plugin3-ap1') }, - ap2() { triggered.push('plugin3-ap2') } - } - - register({ - name: 'plugin1', - url: 'plugin1', - activationPoints: ['ap1'], - active: true - }) - register({ - name: 'plugin2', - url: 'plugin2', - activationPoints: ['ap2'], - active: true - }) - register({ - name: 'plugin3', - url: 'plugin3', - activationPoints: ['ap1', 'ap2'], - active: true - }) - - await trigger('ap1') - - expect(triggered).toEqual(['plugin1-ap1', 'plugin3-ap1']) - } - ) - - it('should return an error if an activation point is triggered on a plugin that does not include it', - async () => { - mockPlugins.plugin1 = { - wrongAp() { } - } - - register({ - name: 'plugin1', - url: 'plugin1', - activationPoints: ['ap1'] - }) - - await expect(() => trigger('ap1')).rejects.toThrow(/was triggered but does not exist on plugin/) - } - ) - - it('should provide the registered extension points to the triggered activation point if presetEPs is set to true in the setup', - async () => { - setup({ - importer(plugin) { return 
mockPlugins[plugin] }, - presetEPs: true, - }) - - let ap1Res - - mockPlugins.plugin1 = { - ap1: eps => ap1Res = eps - } - register({ - name: 'plugin1', - url: 'plugin1', - activationPoints: ['ap1'] - }) - - add('ep1') - add('ep2') - - await trigger('ap1') - - expect(ap1Res.ep1.constructor.name).toEqual('ExtensionPoint') - expect(ap1Res.ep2.constructor.name).toEqual('ExtensionPoint') - } - ) - - it('should allow registration, execution and serial execution of execution points when an activation point is triggered if presetEPs is set to false in the setup', - async () => { - setup({ - importer(plugin) { return mockPlugins[plugin] }, - }) - - let ap1Res - - mockPlugins.plugin1 = { - ap1: eps => ap1Res = eps - } - register({ - name: 'plugin1', - url: 'plugin1', - activationPoints: ['ap1'] - }) - - await trigger('ap1') - - expect(typeof ap1Res.register).toBe('function') - expect(typeof ap1Res.execute).toBe('function') - expect(typeof ap1Res.executeSerial).toBe('function') - } - ) - - it('should not provide any reference to extension points during activation point triggering if presetEPs is set to null in the setup', - async () => { - setup({ - importer(plugin) { return mockPlugins[plugin] }, - presetEPs: null, - }) - - let ap1Res = true - - mockPlugins.plugin1 = { - ap1: eps => ap1Res = eps - } - register({ - name: 'plugin1', - url: 'plugin1', - activationPoints: ['ap1'] - }) - - await trigger('ap1') - - expect(ap1Res).not.toBeDefined() - } - ) -}) - -describe('remove and clear', () => { - - beforeEach(() => { - register({ - name: 'plugin1', - url: 'plugin1', - activationPoints: ['ap1', 'ap2'], - active: true - }) - - register({ - name: 'plugin2', - url: 'plugin2', - activationPoints: ['ap2', 'ap3'], - active: true - }) - }) - it('should remove all and only the activations for the given plugin from the register when removing activations', - () => { - remove('plugin1') - - expect(get()).toEqual([ - { - plugin: 'plugin2', - url: 'plugin2', - activationPoint: 'ap2', - 
activated: false - }, - { - plugin: 'plugin2', - url: 'plugin2', - activationPoint: 'ap3', - activated: false - }, - ]) - } - ) - - it('should not remove any activations from the register if no plugin name is provided', - () => { - remove() - - expect(get()).toEqual([ - { - plugin: 'plugin1', - url: 'plugin1', - activationPoint: 'ap1', - activated: false - }, - { - plugin: 'plugin1', - url: 'plugin1', - activationPoint: 'ap2', - activated: false - }, - { - plugin: 'plugin2', - url: 'plugin2', - activationPoint: 'ap2', - activated: false - }, - { - plugin: 'plugin2', - url: 'plugin2', - activationPoint: 'ap3', - activated: false - }, - ]) - } - ) - - it('should remove all activations from the register when clearing the register', - () => { - clear() - - expect(get()).toEqual([]) - } - ) -}) diff --git a/electron/core/plugin-manager/execution/extension-manager.test.js b/electron/core/plugin-manager/execution/extension-manager.test.js deleted file mode 100644 index dd934f872..000000000 --- a/electron/core/plugin-manager/execution/extension-manager.test.js +++ /dev/null @@ -1,116 +0,0 @@ -import { add, remove, register, get, execute, executeSerial, unregisterAll } from './extension-manager' -import ExtensionPoint from './ExtensionPoint' - -beforeEach(() => { - add('ep1') - add('ep2') -}) - -afterEach(() => { - remove('ep1') - remove('ep2') - remove('ep3') -}) - -describe('get', () => { - it('should return the extension point with the given name if it exists', () => { - expect(get('ep1')).toBeInstanceOf(ExtensionPoint) - }) - - it('should return all extension points if no name is provided', () => { - expect(get()).toEqual(expect.objectContaining({ ep1: expect.any(ExtensionPoint) })) - expect(get()).toEqual(expect.objectContaining({ ep2: expect.any(ExtensionPoint) })) - }) -}) - -describe('Add and remove', () => { - it('should add a new extension point with the given name using the add function', () => { - add('ep1') - - expect(get('ep1')).toBeInstanceOf(ExtensionPoint) 
- }) - - it('should remove only the extension point with the given name using the remove function', () => { - remove('ep1') - - expect(get()).not.toEqual(expect.objectContaining({ ep1: expect.anything() })) - expect(get()).toEqual(expect.objectContaining({ ep2: expect.any(ExtensionPoint) })) - }) - - it('should not remove any extension points if no name is provided using the remove function', () => { - remove() - - expect(get()).toEqual(expect.objectContaining({ ep1: expect.any(ExtensionPoint) })) - expect(get()).toEqual(expect.objectContaining({ ep2: expect.any(ExtensionPoint) })) - }) -}) - -describe('register', () => { - it('should register an extension to an existing extension point if the point has already been created', () => { - register('ep1', 'extension1', { foo: 'bar' }) - - expect(get('ep1')._extensions).toContainEqual(expect.objectContaining({ name: 'extension1' })) - }) - - it('should create an extension point and register an extension to it if the point has not yet been created', () => { - register('ep3', 'extension1', { foo: 'bar' }) - - expect(get('ep3')._extensions).toContainEqual(expect.objectContaining({ name: 'extension1' })) - }) -}) - -describe('unregisterAll', () => { - it('should unregister all extension points matching the give name regex', () => { - // Register example extensions - register('ep1', 'remove1', { foo: 'bar' }) - register('ep2', 'remove2', { foo: 'bar' }) - register('ep1', 'keep', { foo: 'bar' }) - - // Remove matching extensions - unregisterAll(/remove/) - - // Extract all registered extensions - const eps = Object.values(get()).map(ep => ep._extensions) - const extensions = eps.flat() - - // Test extracted extensions - expect(extensions).toContainEqual(expect.objectContaining({ name: 'keep' })) - expect(extensions).not.toContainEqual(expect.objectContaining({ name: 'ep1' })) - expect(extensions).not.toContainEqual(expect.objectContaining({ name: 'ep2' })) - }) -}) - -describe('execute', () => { - it('should execute the 
extensions registered to the named extension point with the provided input', () => { - const result = [] - register('ep1', 'extension1', input => result.push(input + 'bar')) - register('ep1', 'extension2', input => result.push(input + 'baz')) - - execute('ep1', 'foo') - - expect(result).toEqual(['foobar', 'foobaz']) - }) - - it('should throw an error if the named extension point does not exist', () => { - register('ep1', 'extension1', { foo: 'bar' }) - - expect(() => execute('ep3')).toThrow(/not a valid extension point/) - }) -}) - -describe('executeSerial', () => { - it('should execute the extensions in serial registered to the named extension point with the provided input', async () => { - register('ep1', 'extension1', input => input + 'bar') - register('ep1', 'extension2', input => input + 'baz') - - const result = await executeSerial('ep1', 'foo') - - expect(result).toEqual('foobarbaz') - }) - - it('should throw an error if the named extension point does not exist', () => { - register('ep1', 'extension1', { foo: 'bar' }) - - expect(() => executeSerial('ep3')).toThrow(/not a valid extension point/) - }) -}) diff --git a/electron/core/plugin-manager/execution/index.test.js b/electron/core/plugin-manager/execution/index.test.js deleted file mode 100644 index 75f8e5933..000000000 --- a/electron/core/plugin-manager/execution/index.test.js +++ /dev/null @@ -1,28 +0,0 @@ -import { setup } from "." 
-import { importer, presetEPs } from "./import-manager" - -describe('setup', () => { - const mockImporter = jest.fn() - - it('should store the importer function', () => { - setup({ importer: mockImporter }) - - expect(importer).toBe(mockImporter) - }) - - it('should set presetEPS to false if not provided', () => { - expect(presetEPs).toBe(false) - }) - - it('should set presetEPS to the provided value if it is true', () => { - setup({ presetEPs: true }) - - expect(presetEPs).toBe(true) - }) - - it('should set presetEPS to the provided value if it is null', () => { - setup({ presetEPs: null }) - - expect(presetEPs).toBe(null) - }) -}) \ No newline at end of file diff --git a/electron/core/plugin-manager/facade/index.test.js b/electron/core/plugin-manager/facade/index.test.js deleted file mode 100644 index 0e4bae62d..000000000 --- a/electron/core/plugin-manager/facade/index.test.js +++ /dev/null @@ -1,196 +0,0 @@ -jest.mock('electron', () => { - const handlers = {} - - return { - ipcMain: { - handle(channel, callback) { - handlers[channel] = callback - } - }, - ipcRenderer: { - invoke(channel, ...args) { - return Promise.resolve(handlers[channel].call(undefined, 'event', ...args)) - } - }, - webContents: { - getAllWebContents: jest.fn(() => []) - }, - contextBridge: { - exposeInMainWorld(key, val) { - global.window = { [key]: val } - } - } - } -}) - -jest.mock('../pluginMgr/store', () => { - const setActive = jest.fn(() => true) - const uninstall = jest.fn() - const update = jest.fn(() => true) - const isUpdateAvailable = jest.fn(() => false) - - class Plugin { - constructor(name) { - this.name = name - this.activationPoints = ['test'] - } - setActive = setActive - uninstall = uninstall - update = update - isUpdateAvailable = isUpdateAvailable - } - - return { - getPlugin: jest.fn(name => new Plugin(name)), - getActivePlugins: jest.fn(() => [new Plugin('test')]), - installPlugins: jest.fn(async plugins => plugins.map(name => new Plugin(name))), - removePlugin: 
jest.fn() - } -}) - -const { rmSync } = require('fs') -const { webContents } = require('electron') -const useFacade = require('./index') -const { getActive, install, toggleActive, uninstall, update, updatesAvailable, registerActive } = require('../execution/facade') -const { setPluginsPath, setConfirmInstall } = require('../pluginMgr/globals') -const router = require('../pluginMgr/router') -const { getPlugin, getActivePlugins, removePlugin } = require('../pluginMgr/store') -const { get: getActivations } = require('../execution/activation-manager') - -const pluginsPath = './testPlugins' -const confirmInstall = jest.fn(() => true) - -beforeAll(async () => { - setPluginsPath(pluginsPath) - router() - useFacade() -}) - -afterAll(() => { - rmSync(pluginsPath, { recursive: true }) -}) - -describe('install', () => { - it('should return cancelled state if the confirmPlugin callback returns falsy', async () => { - setConfirmInstall(() => false) - const plugins = await install(['test-install']) - expect(plugins).toEqual(false) - }) - - it('should perform a security check of the install using confirmInstall if facade is used', async () => { - setConfirmInstall(confirmInstall) - await install(['test-install']) - expect(confirmInstall.mock.calls.length).toBeTruthy() - }) - - it('should register all installed plugins', async () => { - const pluginName = 'test-install' - await install([pluginName]) - expect(getActivations()).toContainEqual(expect.objectContaining({ - plugin: pluginName - })) - }) - - it('should return a list of plugins', async () => { - setConfirmInstall(confirmInstall) - const pluginName = 'test-install' - const plugins = await install([pluginName]) - expect(plugins).toEqual([expect.objectContaining({ name: pluginName })]) - }) -}) - -describe('uninstall', () => { - it('should uninstall all plugins with the provided name, remove it from the store and refresh all renderers', async () => { - // Reset mock functions - const mockUninstall = getPlugin().uninstall - 
mockUninstall.mockClear() - removePlugin.mockClear() - webContents.getAllWebContents.mockClear() - getPlugin.mockClear() - - // Uninstall plugins - const specs = ['test-uninstall-1', 'test-uninstall-2'] - await uninstall(specs) - - // Test result - expect(getPlugin.mock.calls).toEqual(specs.map(spec => [spec])) - expect(mockUninstall.mock.calls.length).toBeTruthy() - expect(removePlugin.mock.calls.length).toBeTruthy() - expect(webContents.getAllWebContents.mock.calls.length).toBeTruthy() - }) -}) - -describe('getActive', () => { - it('should return all active plugins', async () => { - getActivePlugins.mockClear() - await getActive() - expect(getActivePlugins.mock.calls.length).toBeTruthy() - }) -}) - -describe('registerActive', () => { - it('should register all active plugins', async () => { - await registerActive() - expect(getActivations()).toContainEqual(expect.objectContaining({ - plugin: 'test' - })) - }) -}) - -describe('update', () => { - const specs = ['test-uninstall-1', 'test-uninstall-2'] - const mockUpdate = getPlugin().update - - beforeAll(async () => { - // Reset mock functions - mockUpdate.mockClear() - webContents.getAllWebContents.mockClear() - getPlugin.mockClear() - - // Update plugins - await update(specs) - }) - - it('should call the update function on all provided plugins', async () => { - // Check result - expect(getPlugin.mock.calls).toEqual(specs.map(spec => [spec])) - expect(mockUpdate.mock.calls.length).toBe(2) - }) - - it('should reload the renderers if reload is true', () => { - expect(webContents.getAllWebContents.mock.calls.length).toBeTruthy() - }) - - it('should not reload the renderer if reload is false', async () => { - webContents.getAllWebContents.mockClear() - await update(['test-uninstall'], false) - expect(webContents.getAllWebContents.mock.calls.length).toBeFalsy() - }) -}) - -describe('toggleActive', () => { - it('call the setActive function on the plugin with the provided name, with the provided active state', async () => 
{ - await toggleActive('test-toggleActive', true) - expect(getPlugin.mock.lastCall).toEqual(['test-toggleActive']) - const mockSetActive = getPlugin().setActive - expect(mockSetActive.mock.lastCall).toEqual([true]) - }) -}) - -describe('updatesAvailable', () => { - it('should return the new versions for the provided plugins if provided', async () => { - // Reset mock functions - const mockIsUpdAvailable = getPlugin().isUpdateAvailable - mockIsUpdAvailable.mockClear() - getPlugin.mockClear() - - // Get available updates - const testPlugin1 = 'test-plugin-1' - const testPlugin2 = 'test-update-2' - const updates = await updatesAvailable([testPlugin1, testPlugin2]) - expect(updates).toEqual({ - [testPlugin1]: false, - [testPlugin2]: false, - }) - }) -}) diff --git a/electron/core/plugin-manager/pluginMgr/Plugin.test.js b/electron/core/plugin-manager/pluginMgr/Plugin.test.js deleted file mode 100644 index c0326bd06..000000000 --- a/electron/core/plugin-manager/pluginMgr/Plugin.test.js +++ /dev/null @@ -1,212 +0,0 @@ -import { init } from "." 
-import { join } from 'path' -import Plugin from "./Plugin" -import { mkdirSync, writeFileSync, existsSync, readFileSync, rmSync } from "fs" - -const pluginsDir = './testPlugins' -const testPluginDir = './testPluginSrc' -const testPluginName = 'test-plugin' -const manifest = join(testPluginDir, 'package.json') -const main = 'index' - -/** @type Plugin */ -let plugin - -beforeAll(() => { - init({ - confirmInstall: () => true, - pluginsPath: pluginsDir, - }) - - mkdirSync(testPluginDir) - - writeFileSync(manifest, JSON.stringify({ - name: testPluginName, - activationPoints: [], - main, - }), 'utf8') - - plugin = new Plugin(testPluginDir) -}) - -afterAll(() => { - rmSync(pluginsDir, { recursive: true }) - rmSync(testPluginDir, { recursive: true }) -}) - - -describe('subscribe', () => { - let res = false - it('should register the provided callback', () => { - plugin.subscribe('test', () => res = true) - plugin.setActive(true) - - expect(res).toBeTruthy() - }) -}) - -describe('unsubscribe', () => { - it(`should remove the provided callback from the register - after which it should not be executed anymore when the plugin is updated`, () => { - let res = false - plugin.subscribe('test', () => res = true) - plugin.unsubscribe('test') - plugin.setActive(true) - - expect(res).toBeFalsy() - }) -}) - -describe('install', () => { - beforeAll(async () => { - await plugin._install() - }) - - it('should store all the relevant manifest values on the plugin', async () => { - expect(plugin).toMatchObject({ - origin: testPluginDir, - installOptions: { - version: false, - fullMetadata: false, - }, - name: testPluginName, - url: `plugin://${testPluginName}/${main}`, - activationPoints: [] - }) - }) - - it('should create a folder for the plugin if it does not yet exist and copy the plugin files to it', () => { - expect(existsSync(join(pluginsDir, testPluginName))).toBeTruthy() - }) - - it('should replace the existing plugin files in the plugin folder if it already exist', async () => { - 
writeFileSync(manifest, JSON.stringify({ - name: testPluginName, - activationPoints: [], - main: 'updated', - }), 'utf8') - - await plugin._install() - - const savedPkg = JSON.parse(readFileSync(join(pluginsDir, testPluginName, 'package.json'))) - - expect(savedPkg.main).toBe('updated') - }) - - it('should throw an error and the plugin should be set to inactive if no manifest could be found', async () => { - rmSync(join(testPluginDir, 'package.json')) - - await expect(() => plugin._install()).rejects.toThrow(/does not contain a valid manifest/) - }) - - it('should throw an error and the plugin should be set to inactive if plugin does not contain any activation points', async () => { - writeFileSync(manifest, JSON.stringify({ - name: testPluginName, - main, - }), 'utf8') - - await expect(() => plugin._install()).rejects.toThrow('The plugin does not contain any activation points') - expect(plugin.active).toBe(false) - }) -}) - -describe('update', () => { - let updatedPlugin - let subscription = false - let beforeUpd - - beforeAll(async () => { - writeFileSync(manifest, JSON.stringify({ - name: testPluginName, - activationPoints: [], - version: '0.0.1', - main, - }), 'utf8') - - await plugin._install() - - plugin.subscribe('test', () => subscription = true) - beforeUpd = Object.assign({}, plugin) - - await plugin.update() - }) - - it('should not do anything if no version update is available', () => { - expect(beforeUpd).toMatchObject(plugin) - }) - - it('should update the plugin files to the latest version if there is a new version available for the plugin', async () => { - writeFileSync(manifest, JSON.stringify({ - name: testPluginName, - activationPoints: [], - version: '0.0.2', - main, - }), 'utf8') - - await plugin.update() - - expect(plugin).toMatchObject({ - origin: testPluginDir, - installOptions: { - version: false, - fullMetadata: false, - }, - name: testPluginName, - version: '0.0.2', - url: `plugin://${testPluginName}/${main}`, - activationPoints: [] - }) - 
}) - - it('should execute callbacks subscribed to this plugin, providing the plugin as a parameter', () => { - expect(subscription).toBeTruthy() - }) -}) - -describe('isUpdateAvailable', () => { - it('should return false if no new version is available', async () => { - await expect(plugin.isUpdateAvailable()).resolves.toBe(false) - }) - - it('should return the latest version number if a new version is available', async () => { - writeFileSync(manifest, JSON.stringify({ - name: testPluginName, - activationPoints: [], - version: '0.0.3', - main, - }), 'utf8') - - await expect(plugin.isUpdateAvailable()).resolves.toBe('0.0.3') - }) -}) - -describe('setActive', () => { - it('should set the plugin to be active', () => { - plugin.setActive(true) - expect(plugin.active).toBeTruthy() - }) - - it('should execute callbacks subscribed to this plugin, providing the plugin as a parameter', () => { - let res = false - plugin.subscribe('test', () => res = true) - plugin.setActive(true) - - expect(res).toBeTruthy() - }) -}) - -describe('uninstall', () => { - let subscription = false - beforeAll(async () => { - plugin.subscribe('test', () => subscription = true) - await plugin.uninstall() - }) - - it('should remove the installed plugin from the plugins folder', () => { - expect(existsSync(join(pluginsDir, testPluginName))).toBe(false) - }) - - it('should execute callbacks subscribed to this plugin, providing the plugin as a parameter', () => { - expect(subscription).toBeTruthy() - }) -}) diff --git a/electron/core/plugin-manager/pluginMgr/globals.js b/electron/core/plugin-manager/pluginMgr/globals.js deleted file mode 100644 index a0fc3718d..000000000 --- a/electron/core/plugin-manager/pluginMgr/globals.js +++ /dev/null @@ -1,57 +0,0 @@ -import { existsSync, mkdirSync, writeFileSync } from "fs" -import { join, resolve } from "path" - -export let pluginsPath = null - -/** - * @private - * Set path to plugins directory and create the directory if it does not exist. 
- * @param {string} plgPath path to plugins directory - */ -export function setPluginsPath(plgPath) { - // Create folder if it does not exist - let plgDir - try { - plgDir = resolve(plgPath) - if (plgDir.length < 2) throw new Error() - - if (!existsSync(plgDir)) mkdirSync(plgDir) - - const pluginsJson = join(plgDir, 'plugins.json') - if (!existsSync(pluginsJson)) writeFileSync(pluginsJson, '{}', 'utf8') - - pluginsPath = plgDir - - } catch (error) { - throw new Error('Invalid path provided to the plugins folder') - } - -} - -/** -* @private - * Get the path to the plugins.json file. - * @returns location of plugins.json - */ -export function getPluginsFile() { return join(pluginsPath, 'plugins.json') } - - -export let confirmInstall = function () { - return new Error( - 'The facade.confirmInstall callback needs to be set in when initializing Pluggable Electron in the main process.' - ) -} - -/** - * @private - * Set callback to use as confirmInstall. - * @param {confirmInstall} cb Callback - */ -export function setConfirmInstall(cb) { confirmInstall = cb } - -/** - * This function is executed when plugins are installed to verify that the user indeed wants to install the plugin. 
- * @callback confirmInstall - * @param {Array.} plg The specifiers used to locate the packages (from NPM or local file) - * @returns {Promise} Whether to proceed with the plugin installation - */ diff --git a/electron/core/plugin-manager/pluginMgr/index.test.js b/electron/core/plugin-manager/pluginMgr/index.test.js deleted file mode 100644 index 37055db62..000000000 --- a/electron/core/plugin-manager/pluginMgr/index.test.js +++ /dev/null @@ -1,150 +0,0 @@ -import { usePlugins, getStore, init } from './index' -import { installPlugins, getPlugin, getAllPlugins, getActivePlugins, addPlugin, removePlugin } from './store' -import Plugin from './Plugin' -import { existsSync, rmSync, mkdirSync, writeFileSync } from 'fs' -import { join } from 'path' -import { protocol } from 'electron' - -// Set up variables for test folders and test plugins. -const pluginDir = './testPlugins' -const registeredPluginName = 'registered-plugin' -const demoPlugin = { - origin: ".\\demo-plugin\\demo-plugin-1.5.0.tgz", - installOptions: { - version: false, - fullMetadata: false - }, - name: "demoPlugin", - version: "1.5.0", - activationPoints: [ - "init" - ], - main: "index.js", - _active: true, - url: "plugin://demo-plugin/index.js" -} - -describe('before setting a plugin path', () => { - describe('getStore', () => { - it('should throw an error if called without a plugin path set', () => { - expect(() => getStore()).toThrowError('The plugin path has not yet been set up. 
Please run usePlugins before accessing the store') - }) - }) - - describe('usePlugins', () => { - it('should throw an error if called without a plugin path whilst no plugin path is set', () => { - expect(() => usePlugins()).toThrowError('A path to the plugins folder is required to use Pluggable Electron') - }) - - it('should throw an error if called with an invalid plugin path', () => { - expect(() => usePlugins('http://notsupported')).toThrowError('Invalid path provided to the plugins folder') - }) - - it('should create the plugin path if it does not yet exist', () => { - // Execute usePlugins with a folder that does not exist - const newPluginDir = './test-new-plugins' - usePlugins(newPluginDir) - expect(existsSync(newPluginDir)).toBe(true) - - // Remove created folder to clean up - rmSync(newPluginDir, { recursive: true }) - }) - }) -}) - -describe('after setting a plugin path', () => { - let pm - - beforeAll(() => { - // Create folders to contain plugins - mkdirSync(pluginDir) - - // Create initial - writeFileSync(join(pluginDir, 'plugins.json'), JSON.stringify({ demoPlugin }), 'utf8') - - // Register a plugin before using plugins - const registeredPLugin = new Plugin(registeredPluginName) - registeredPLugin.name = registeredPluginName - addPlugin(registeredPLugin, false) - - // Load plugins - pm = usePlugins(pluginDir) - }) - - afterAll(() => { - rmSync(pluginDir, { recursive: true }) - }) - - describe('getStore', () => { - it('should return the plugin lifecycle functions if no plugin path is provided', () => { - expect(getStore()).toEqual({ - installPlugins, - getPlugin, - getAllPlugins, - getActivePlugins, - removePlugin, - }) - }) - }) - - describe('usePlugins', () => { - it('should return the plugin lifecycle functions if a plugin path is provided', () => { - expect(pm).toEqual({ - installPlugins, - getPlugin, - getAllPlugins, - getActivePlugins, - removePlugin, - }) - }) - - it('should load the plugins defined in plugins.json in the provided plugins 
folder if a plugin path is provided', () => { - expect(getPlugin('demoPlugin')).toEqual(demoPlugin) - }) - - it('should unregister any registered plugins before registering the new ones if a plugin path is provided', () => { - expect(() => getPlugin(registeredPluginName)).toThrowError(`Plugin ${registeredPluginName} does not exist`) - }) - }) -}) - -describe('init', () => { - // Enabling the facade and registering the confirm install function is tested with the router. - let pm - - beforeAll(() => { - // Create test plugins folder - mkdirSync(pluginDir) - - // Initialize Pluggable Electron without a plugin folder - pm = init({ confirmInstall: () => true }) - }) - - afterAll(() => { - // Remove test plugins folder - rmSync(pluginDir, { recursive: true }) - }) - - it('should make the plugin files available through the plugin protocol', async () => { - expect(protocol.isProtocolRegistered('plugin')).toBeTruthy() - }) - - it('should return an empty object if no plugin path is provided', () => { - expect(pm).toEqual({}) - }) - - it('should return the plugin lifecycle functions if a plugin path is provided', () => { - pm = init({ - confirmInstall: () => true, - pluginsPath: pluginDir, - }) - - expect(pm).toEqual({ - installPlugins, - getPlugin, - getAllPlugins, - getActivePlugins, - removePlugin, - }) - }) -}) \ No newline at end of file diff --git a/electron/core/plugin-manager/pluginMgr/router.js b/electron/core/plugin-manager/pluginMgr/router.js deleted file mode 100644 index 17e5ba2f5..000000000 --- a/electron/core/plugin-manager/pluginMgr/router.js +++ /dev/null @@ -1,91 +0,0 @@ -import { ipcMain, webContents } from "electron" - -import { getPlugin, getActivePlugins, installPlugins, removePlugin, getAllPlugins } from "./store" -import { pluginsPath, confirmInstall } from './globals' - -// Throw an error if pluginsPath has not yet been provided by usePlugins. 
-const checkPluginsPath = () => { - if (!pluginsPath) throw Error('Path to plugins folder has not yet been set up.') -} -let active = false -/** - * Provide the renderer process access to the plugins. - **/ -export default function () { - if (active) return - // Register IPC route to install a plugin - ipcMain.handle('pluggable:install', async (e, plugins) => { - checkPluginsPath() - - // Validate install request from backend for security. - const specs = plugins.map(plg => typeof plg === 'object' ? plg.specifier : plg) - const conf = await confirmInstall(specs) - if (!conf) return { cancelled: true } - - // Install and activate all provided plugins - const installed = await installPlugins(plugins) - return JSON.parse(JSON.stringify(installed)) - }) - - // Register IPC route to uninstall a plugin - ipcMain.handle('pluggable:uninstall', async (e, plugins, reload) => { - checkPluginsPath() - - // Uninstall all provided plugins - for (const plg of plugins) { - const plugin = getPlugin(plg) - await plugin.uninstall() - removePlugin(plugin.name) - } - - // Reload all renderer pages if needed - reload && webContents.getAllWebContents().forEach(wc => wc.reload()) - return true - }) - - // Register IPC route to update a plugin - ipcMain.handle('pluggable:update', (e, plugins, reload) => { - checkPluginsPath() - - // Update all provided plugins - let updated = [] - for (const plg of plugins) { - const plugin = getPlugin(plg) - const res = plugin.update() - if (res) updated.push(plugin) - } - - // Reload all renderer pages if needed - if (updated.length && reload) webContents.getAllWebContents().forEach(wc => wc.reload()) - - return JSON.parse(JSON.stringify(updated)) - }) - - // Register IPC route to check if updates are available for a plugin - ipcMain.handle('pluggable:updatesAvailable', (e, names) => { - checkPluginsPath() - - const plugins = names ? 
names.map(name => getPlugin(name)) : getAllPlugins() - - const updates = {} - for (const plugin of plugins) { - updates[plugin.name] = plugin.isUpdateAvailable() - } - return updates - }) - - // Register IPC route to get the list of active plugins - ipcMain.handle('pluggable:getActivePlugins', () => { - checkPluginsPath() - return JSON.parse(JSON.stringify(getActivePlugins())) - }) - - // Register IPC route to toggle the active state of a plugin - ipcMain.handle('pluggable:togglePluginActive', (e, plg, active) => { - checkPluginsPath() - const plugin = getPlugin(plg) - return JSON.parse(JSON.stringify(plugin.setActive(active))) - }) - - active = true -} diff --git a/electron/core/plugin-manager/pluginMgr/store.test.js b/electron/core/plugin-manager/pluginMgr/store.test.js deleted file mode 100644 index e99cb7c7d..000000000 --- a/electron/core/plugin-manager/pluginMgr/store.test.js +++ /dev/null @@ -1,108 +0,0 @@ -import { getActivePlugins, getAllPlugins, getPlugin, installPlugins } from './store' -import { init } from "." -import { join } from 'path' -import Plugin from "./Plugin" -import { mkdirSync, writeFileSync, rmSync } from "fs" - -// Temporary directory to install plugins to -const pluginsDir = './testPlugins' - -// Temporary directory containing the active plugin to install -const activePluginDir = './activePluginSrc' -const activePluginName = 'active-plugin' -const activeManifest = join(activePluginDir, 'package.json') - -// Temporary directory containing the inactive plugin to install -const inactivePluginDir = './inactivePluginSrc' -const inactivePluginName = 'inactive-plugin' -const inactiveManifest = join(inactivePluginDir, 'package.json') - -// Mock name for the entry file in the plugins -const main = 'index' - -/** @type Array. */ -let activePlugins -/** @type Array. 
*/ -let inactivePlugins - -beforeAll(async () => { - // Initialize pluggable Electron - init({ - confirmInstall: () => true, - pluginsPath: pluginsDir, - }) - - // Create active plugin - mkdirSync(activePluginDir) - writeFileSync(activeManifest, JSON.stringify({ - name: activePluginName, - activationPoints: [], - main, - }), 'utf8') - - // Create active plugin - mkdirSync(inactivePluginDir) - writeFileSync(inactiveManifest, JSON.stringify({ - name: inactivePluginName, - activationPoints: [], - main, - }), 'utf8') - - // Install plugins - activePlugins = await installPlugins([activePluginDir], true) - activePlugins[0].setActive(true) - inactivePlugins = await installPlugins([{ - specifier: inactivePluginDir, - activate: false - }], true) -}) - -afterAll(() => { - // Remove all test files and folders - rmSync(pluginsDir, { recursive: true }) - rmSync(activePluginDir, { recursive: true }) - rmSync(inactivePluginDir, { recursive: true }) -}) - -describe('installPlugins', () => { - it('should create a new plugin found at the given location and return it if store is false', async () => { - const res = await installPlugins([activePluginDir], false) - - expect(res[0]).toBeInstanceOf(Plugin) - }) - - it('should create a new plugin found at the given location and register it if store is true', () => { - expect(activePlugins[0]).toBeInstanceOf(Plugin) - expect(getPlugin(activePluginName)).toBe(activePlugins[0]) - }) - - it('should activate the installed plugin by default', () => { - expect(getPlugin(activePluginName).active).toBe(true) - }) - - it('should set plugin to inactive if activate is set to false in the install options', async () => { - expect(inactivePlugins[0].active).toBe(false) - }) -}) - -describe('getPlugin', () => { - it('should return the plugin with the given name if it is registered', () => { - expect(getPlugin(activePluginName)).toBeInstanceOf(Plugin) - }) - - it('should return an error if the plugin with the given name is not registered', () => { - 
expect(() => getPlugin('wrongName')).toThrowError('Plugin wrongName does not exist') - }) -}) - -describe('getAllPlugins', () => { - it('should return a list of all registered plugins', () => { - expect(getAllPlugins()).toEqual([activePlugins[0], inactivePlugins[0]]) - }) -}) - -describe('getActivePlugins', () => { - it('should return a list of all and only the registered plugins that are active', () => { - expect(getActivePlugins()).toEqual(activePlugins) - }) -}) \ No newline at end of file diff --git a/electron/core/plugin-manager/facade/index.js b/electron/core/plugin/facade.ts similarity index 74% rename from electron/core/plugin-manager/facade/index.js rename to electron/core/plugin/facade.ts index 27a2468c7..bd1089109 100644 --- a/electron/core/plugin-manager/facade/index.js +++ b/electron/core/plugin/facade.ts @@ -1,23 +1,23 @@ const { ipcRenderer, contextBridge } = require("electron"); -function useFacade() { +export function useFacade() { const interfaces = { - install(plugins) { + install(plugins: any[]) { return ipcRenderer.invoke("pluggable:install", plugins); }, - uninstall(plugins, reload) { + uninstall(plugins: any[], reload: boolean) { return ipcRenderer.invoke("pluggable:uninstall", plugins, reload); }, getActive() { return ipcRenderer.invoke("pluggable:getActivePlugins"); }, - update(plugins, reload) { + update(plugins: any[], reload: boolean) { return ipcRenderer.invoke("pluggable:update", plugins, reload); }, - updatesAvailable(plugin) { + updatesAvailable(plugin: any) { return ipcRenderer.invoke("pluggable:updatesAvailable", plugin); }, - toggleActive(plugin, active) { + toggleActive(plugin: any, active: boolean) { return ipcRenderer.invoke("pluggable:togglePluginActive", plugin, active); }, }; @@ -28,5 +28,3 @@ function useFacade() { return interfaces; } - -module.exports = useFacade; diff --git a/electron/core/plugin/globals.ts b/electron/core/plugin/globals.ts new file mode 100644 index 000000000..69df7925c --- /dev/null +++ 
b/electron/core/plugin/globals.ts @@ -0,0 +1,36 @@ +import { existsSync, mkdirSync, writeFileSync } from "fs"; +import { join, resolve } from "path"; + +export let pluginsPath: string | undefined = undefined; + +/** + * @private + * Set path to plugins directory and create the directory if it does not exist. + * @param {string} plgPath path to plugins directory + */ +export function setPluginsPath(plgPath: string) { + // Create folder if it does not exist + let plgDir; + try { + plgDir = resolve(plgPath); + if (plgDir.length < 2) throw new Error(); + + if (!existsSync(plgDir)) mkdirSync(plgDir); + + const pluginsJson = join(plgDir, "plugins.json"); + if (!existsSync(pluginsJson)) writeFileSync(pluginsJson, "{}", "utf8"); + + pluginsPath = plgDir; + } catch (error) { + throw new Error("Invalid path provided to the plugins folder"); + } +} + +/** + * @private + * Get the path to the plugins.json file. + * @returns location of plugins.json + */ +export function getPluginsFile() { + return join(pluginsPath ?? 
"", "plugins.json"); +} \ No newline at end of file diff --git a/electron/core/plugin-manager/pluginMgr/index.js b/electron/core/plugin/index.ts similarity index 54% rename from electron/core/plugin-manager/pluginMgr/index.js rename to electron/core/plugin/index.ts index ae7bff759..e8c64747b 100644 --- a/electron/core/plugin-manager/pluginMgr/index.js +++ b/electron/core/plugin/index.ts @@ -1,40 +1,52 @@ -import { readFileSync } from "fs" -import { protocol } from 'electron' -import { normalize } from "path" +import { readFileSync } from "fs"; +import { protocol } from "electron"; +import { normalize } from "path"; -import Plugin from "./Plugin" -import { getAllPlugins, removePlugin, persistPlugins, installPlugins, getPlugin, getActivePlugins, addPlugin } from "./store" -import { pluginsPath as storedPluginsPath, setPluginsPath, getPluginsFile, setConfirmInstall } from './globals' -import router from "./router" +import Plugin from "./plugin"; +import { + getAllPlugins, + removePlugin, + persistPlugins, + installPlugins, + getPlugin, + getActivePlugins, + addPlugin, +} from "./store"; +import { + pluginsPath as storedPluginsPath, + setPluginsPath, + getPluginsFile, +} from "./globals"; +import router from "./router"; /** * Sets up the required communication between the main and renderer processes. * Additionally sets the plugins up using {@link usePlugins} if a pluginsPath is provided. * @param {Object} options configuration for setting up the renderer facade. - * @param {confirmInstall} [options.confirmInstall] Function to validate that a plugin should be installed. + * @param {confirmInstall} [options.confirmInstall] Function to validate that a plugin should be installed. * @param {Boolean} [options.useFacade=true] Whether to make a facade to the plugins available in the renderer. * @param {string} [options.pluginsPath] Optional path to the plugins folder. 
* @returns {pluginManager|Object} A set of functions used to manage the plugin lifecycle if usePlugins is provided. * @function */ -export function init(options) { - if (!Object.prototype.hasOwnProperty.call(options, 'useFacade') || options.useFacade) { - // Store the confirmInstall function - setConfirmInstall(options.confirmInstall) +export function init(options: any) { + if ( + !Object.prototype.hasOwnProperty.call(options, "useFacade") || + options.useFacade + ) { // Enable IPC to be used by the facade - router() + router(); } // Create plugins protocol to serve plugins to renderer - registerPluginProtocol() + registerPluginProtocol(); // perform full setup if pluginsPath is provided if (options.pluginsPath) { - return usePlugins(options.pluginsPath) + return usePlugins(options.pluginsPath); } - return {} - + return {}; } /** @@ -43,11 +55,11 @@ export function init(options) { * @returns {boolean} Whether the protocol registration was successful */ function registerPluginProtocol() { - return protocol.registerFileProtocol('plugin', (request, callback) => { - const entry = request.url.substr(8) - const url = normalize(storedPluginsPath + entry) - callback({ path: url }) - }) + return protocol.registerFileProtocol("plugin", (request, callback) => { + const entry = request.url.substr(8); + const url = normalize(storedPluginsPath + entry); + callback({ path: url }); + }); } /** @@ -56,34 +68,38 @@ function registerPluginProtocol() { * @param {string} pluginsPath Path to the plugins folder. Required if not yet set up. * @returns {pluginManager} A set of functions used to manage the plugin lifecycle. 
*/ -export function usePlugins(pluginsPath) { - if (!pluginsPath) throw Error('A path to the plugins folder is required to use Pluggable Electron') +export function usePlugins(pluginsPath: string) { + if (!pluginsPath) + throw Error( + "A path to the plugins folder is required to use Pluggable Electron" + ); // Store the path to the plugins folder - setPluginsPath(pluginsPath) + setPluginsPath(pluginsPath); // Remove any registered plugins for (const plugin of getAllPlugins()) { - removePlugin(plugin.name, false) + if (plugin.name) removePlugin(plugin.name, false); } // Read plugin list from plugins folder - const plugins = JSON.parse(readFileSync(getPluginsFile())) + const plugins = JSON.parse(readFileSync(getPluginsFile(), "utf-8")); try { // Create and store a Plugin instance for each plugin in list for (const p in plugins) { - loadPlugin(plugins[p]) + loadPlugin(plugins[p]); } - persistPlugins() - + persistPlugins(); } catch (error) { // Throw meaningful error if plugin loading fails - throw new Error('Could not successfully rebuild list of installed plugins.\n' - + error - + '\nPlease check the plugins.json file in the plugins folder.') + throw new Error( + "Could not successfully rebuild list of installed plugins.\n" + + error + + "\nPlease check the plugins.json file in the plugins folder." 
+ ); } // Return the plugin lifecycle functions - return getStore() + return getStore(); } /** @@ -92,16 +108,24 @@ export function usePlugins(pluginsPath) { * @private * @param {Object} plg Plugin info */ -function loadPlugin(plg) { +function loadPlugin(plg: any) { // Create new plugin, populate it with plg details and save it to the store - const plugin = new Plugin() + const plugin = new Plugin(); for (const key in plg) { - plugin[key] = plg[key] + if (Object.prototype.hasOwnProperty.call(plg, key)) { + // Use Object.defineProperty to set the properties as writable + Object.defineProperty(plugin, key, { + value: plg[key], + writable: true, + enumerable: true, + configurable: true, + }); + } } - addPlugin(plugin, false) - plugin.subscribe('pe-persist', persistPlugins) + addPlugin(plugin, false); + plugin.subscribe("pe-persist", persistPlugins); } /** @@ -110,7 +134,9 @@ function loadPlugin(plg) { */ export function getStore() { if (!storedPluginsPath) { - throw new Error('The plugin path has not yet been set up. Please run usePlugins before accessing the store') + throw new Error( + "The plugin path has not yet been set up. 
Please run usePlugins before accessing the store" + ); } return { @@ -119,5 +145,5 @@ export function getStore() { getAllPlugins, getActivePlugins, removePlugin, - } + }; } diff --git a/electron/core/plugin-manager/pluginMgr/Plugin.js b/electron/core/plugin/plugin.ts similarity index 59% rename from electron/core/plugin-manager/pluginMgr/Plugin.js rename to electron/core/plugin/plugin.ts index 093e4015f..f0fc073d7 100644 --- a/electron/core/plugin-manager/pluginMgr/Plugin.js +++ b/electron/core/plugin/plugin.ts @@ -1,11 +1,11 @@ -import { rmdir } from "fs/promises" -import { resolve, join } from "path" -import { manifest, extract } from "pacote" -import Arborist from '@npmcli/arborist' +import { rmdir } from "fs/promises"; +import { resolve, join } from "path"; +import { manifest, extract } from "pacote"; +import * as Arborist from "@npmcli/arborist"; -import { pluginsPath } from "./globals" +import { pluginsPath } from "./globals"; -/** +/** * An NPM package that can be used as a Pluggable Electron plugin. * Used to hold all the information and functions necessary to handle the plugin lifecycle. */ @@ -21,30 +21,39 @@ class Plugin { * @property {string} description The description of plugin as defined in the manifest. * @property {string} icon The icon of plugin as defined in the manifest. */ + origin?: string; + installOptions: any; + name?: string; + url?: string; + version?: string; + activationPoints?: Array; + main?: string; + description?: string; + icon?: string; /** @private */ - _active = false + _active = false; /** * @private * @property {Object.} #listeners A list of callbacks to be executed when the Plugin is updated. */ - #listeners = {} + listeners: Record void> = {}; /** * Set installOptions with defaults for options that have not been provided. * @param {string} [origin] Original specification provided to fetch the package. * @param {Object} [options] Options provided to pacote when fetching the manifest. 
*/ - constructor(origin, options = {}) { + constructor(origin?: string, options = {}) { const defaultOpts = { version: false, fullMetadata: false, - Arborist - } + Arborist, + }; - this.origin = origin - this.installOptions = { ...defaultOpts, ...options } + this.origin = origin; + this.installOptions = { ...defaultOpts, ...options }; } /** @@ -52,7 +61,10 @@ class Plugin { * @type {string} */ get specifier() { - return this.origin + (this.installOptions.version ? '@' + this.installOptions.version : '') + return ( + this.origin + + (this.installOptions.version ? "@" + this.installOptions.version : "") + ); } /** @@ -60,31 +72,34 @@ class Plugin { * @type {boolean} */ get active() { - return this._active + return this._active; } /** * Set Package details based on it's manifest - * @returns {Promise.} Resolves to true when the action completed + * @returns {Promise.} Resolves to true when the action completed */ - async #getManifest() { + async getManifest() { // Get the package's manifest (package.json object) try { - const mnf = await manifest(this.specifier, this.installOptions) + const mnf = await manifest(this.specifier, this.installOptions); // set the Package properties based on the it's manifest - this.name = mnf.name - this.version = mnf.version - this.activationPoints = mnf.activationPoints || null - this.main = mnf.main - this.description = mnf.description - this.icon = mnf.icon - + this.name = mnf.name; + this.version = mnf.version; + this.activationPoints = mnf.activationPoints + ? 
(mnf.activationPoints as string[]) + : undefined; + this.main = mnf.main; + this.description = mnf.description; + this.icon = mnf.icon as any; } catch (error) { - throw new Error(`Package ${this.origin} does not contain a valid manifest: ${error}`) + throw new Error( + `Package ${this.origin} does not contain a valid manifest: ${error}` + ); } - return true + return true; } /** @@ -95,26 +110,29 @@ class Plugin { async _install() { try { // import the manifest details - await this.#getManifest() + await this.getManifest(); // Install the package in a child folder of the given folder - await extract(this.specifier, join(pluginsPath, this.name), this.installOptions) + await extract( + this.specifier, + join(pluginsPath ?? "", this.name ?? ""), + this.installOptions + ); if (!Array.isArray(this.activationPoints)) - throw new Error('The plugin does not contain any activation points') + throw new Error("The plugin does not contain any activation points"); // Set the url using the custom plugins protocol - this.url = `plugin://${this.name}/${this.main}` - - this.#emitUpdate() + this.url = `plugin://${this.name}/${this.main}`; + this.emitUpdate(); } catch (err) { // Ensure the plugin is not stored and the folder is removed if the installation fails - this.setActive(false) - throw err + this.setActive(false); + throw err; } - return [this] + return [this]; } /** @@ -122,24 +140,24 @@ class Plugin { * @param {string} name name of the callback to register * @param {callback} cb The function to execute on update */ - subscribe(name, cb) { - this.#listeners[name] = cb + subscribe(name: string, cb: () => void) { + this.listeners[name] = cb; } /** * Remove subscription * @param {string} name name of the callback to remove */ - unsubscribe(name) { - delete this.#listeners[name] + unsubscribe(name: string) { + delete this.listeners[name]; } /** * Execute listeners */ - #emitUpdate() { - for (const cb in this.#listeners) { - this.#listeners[cb].call(null, this) + emitUpdate() { + 
for (const cb in this.listeners) { + this.listeners[cb].call(null, this); } } @@ -149,13 +167,13 @@ class Plugin { * @returns {boolean} Whether an update was performed. */ async update(version = false) { - if (this.isUpdateAvailable()) { - this.installOptions.version = version - await this._install(false) - return true + if (await this.isUpdateAvailable()) { + this.installOptions.version = version; + await this._install(); + return true; } - return false + return false; } /** @@ -163,19 +181,21 @@ class Plugin { * @returns the latest available version if a new version is available or false if not. */ async isUpdateAvailable() { - const mnf = await manifest(this.origin) - return mnf.version !== this.version ? mnf.version : false + if (this.origin) { + const mnf = await manifest(this.origin); + return mnf.version !== this.version ? mnf.version : false; + } } /** * Remove plugin and refresh renderers. - * @returns {Promise} + * @returns {Promise} */ async uninstall() { - const plgPath = resolve(pluginsPath, this.name) - await rmdir(plgPath, { recursive: true }) + const plgPath = resolve(pluginsPath ?? "", this.name ?? 
""); + await rmdir(plgPath, { recursive: true }); - this.#emitUpdate() + this.emitUpdate(); } /** @@ -183,11 +203,11 @@ class Plugin { * @param {boolean} active State to set _active to * @returns {Plugin} This plugin */ - setActive(active) { - this._active = active - this.#emitUpdate() - return this + setActive(active: boolean) { + this._active = active; + this.emitUpdate(); + return this; } } -export default Plugin +export default Plugin; diff --git a/electron/core/plugin/router.ts b/electron/core/plugin/router.ts new file mode 100644 index 000000000..09c79485b --- /dev/null +++ b/electron/core/plugin/router.ts @@ -0,0 +1,97 @@ +import { ipcMain, webContents } from "electron"; + +import { + getPlugin, + getActivePlugins, + installPlugins, + removePlugin, + getAllPlugins, +} from "./store"; +import { pluginsPath } from "./globals"; +import Plugin from "./plugin"; + +// Throw an error if pluginsPath has not yet been provided by usePlugins. +const checkPluginsPath = () => { + if (!pluginsPath) + throw Error("Path to plugins folder has not yet been set up."); +}; +let active = false; +/** + * Provide the renderer process access to the plugins. 
+ **/ +export default function () { + if (active) return; + // Register IPC route to install a plugin + ipcMain.handle("pluggable:install", async (e, plugins) => { + checkPluginsPath(); + + // Install and activate all provided plugins + const installed = await installPlugins(plugins); + return JSON.parse(JSON.stringify(installed)); + }); + + // Register IPC route to uninstall a plugin + ipcMain.handle("pluggable:uninstall", async (e, plugins, reload) => { + checkPluginsPath(); + + // Uninstall all provided plugins + for (const plg of plugins) { + const plugin = getPlugin(plg); + await plugin.uninstall(); + if (plugin.name) removePlugin(plugin.name); + } + + // Reload all renderer pages if needed + reload && webContents.getAllWebContents().forEach((wc) => wc.reload()); + return true; + }); + + // Register IPC route to update a plugin + ipcMain.handle("pluggable:update", async (e, plugins, reload) => { + checkPluginsPath(); + + // Update all provided plugins + const updated: Plugin[] = []; + for (const plg of plugins) { + const plugin = getPlugin(plg); + const res = await plugin.update(); + if (res) updated.push(plugin); + } + + // Reload all renderer pages if needed + if (updated.length && reload) + webContents.getAllWebContents().forEach((wc) => wc.reload()); + + return JSON.parse(JSON.stringify(updated)); + }); + + // Register IPC route to check if updates are available for a plugin + ipcMain.handle("pluggable:updatesAvailable", (e, names) => { + checkPluginsPath(); + + const plugins = names + ? 
names.map((name: string) => getPlugin(name)) + : getAllPlugins(); + + const updates: Record = {}; + for (const plugin of plugins) { + updates[plugin.name] = plugin.isUpdateAvailable(); + } + return updates; + }); + + // Register IPC route to get the list of active plugins + ipcMain.handle("pluggable:getActivePlugins", () => { + checkPluginsPath(); + return JSON.parse(JSON.stringify(getActivePlugins())); + }); + + // Register IPC route to toggle the active state of a plugin + ipcMain.handle("pluggable:togglePluginActive", (e, plg, active) => { + checkPluginsPath(); + const plugin = getPlugin(plg); + return JSON.parse(JSON.stringify(plugin.setActive(active))); + }); + + active = true; +} diff --git a/electron/core/plugin-manager/pluginMgr/store.js b/electron/core/plugin/store.ts similarity index 65% rename from electron/core/plugin-manager/pluginMgr/store.js rename to electron/core/plugin/store.ts index 16a647930..cfd25e5ca 100644 --- a/electron/core/plugin-manager/pluginMgr/store.js +++ b/electron/core/plugin/store.ts @@ -8,9 +8,9 @@ * @prop {removePlugin} removePlugin */ -import { writeFileSync } from "fs" -import Plugin from "./Plugin" -import { getPluginsFile } from './globals' +import { writeFileSync } from "fs"; +import Plugin from "./plugin"; +import { getPluginsFile } from "./globals"; /** * @module store @@ -21,7 +21,7 @@ import { getPluginsFile } from './globals' * Register of installed plugins * @type {Object.} plugin - List of installed plugins */ -const plugins = {} +const plugins: Record = {}; /** * Get a plugin from the stored plugins. 
@@ -29,12 +29,12 @@ const plugins = {} * @returns {Plugin} Retrieved plugin * @alias pluginManager.getPlugin */ -export function getPlugin(name) { +export function getPlugin(name: string) { if (!Object.prototype.hasOwnProperty.call(plugins, name)) { - throw new Error(`Plugin ${name} does not exist`) + throw new Error(`Plugin ${name} does not exist`); } - return plugins[name] + return plugins[name]; } /** @@ -42,7 +42,9 @@ export function getPlugin(name) { * @returns {Array.} All plugin objects * @alias pluginManager.getAllPlugins */ -export function getAllPlugins() { return Object.values(plugins) } +export function getAllPlugins() { + return Object.values(plugins); +} /** * Get list of active plugin objects. @@ -50,7 +52,7 @@ export function getAllPlugins() { return Object.values(plugins) } * @alias pluginManager.getActivePlugins */ export function getActivePlugins() { - return Object.values(plugins).filter(plugin => plugin.active) + return Object.values(plugins).filter((plugin) => plugin.active); } /** @@ -60,10 +62,10 @@ export function getActivePlugins() { * @returns {boolean} Whether the delete was successful * @alias pluginManager.removePlugin */ -export function removePlugin(name, persist = true) { - const del = delete plugins[name] - if (persist) persistPlugins() - return del +export function removePlugin(name: string, persist = true) { + const del = delete plugins[name]; + if (persist) persistPlugins(); + return del; } /** @@ -72,11 +74,11 @@ export function removePlugin(name, persist = true) { * @param {boolean} persist Whether to save the changes to plugins to file * @returns {void} */ -export function addPlugin(plugin, persist = true) { - plugins[plugin.name] = plugin +export function addPlugin(plugin: Plugin, persist = true) { + if (plugin.name) plugins[plugin.name] = plugin; if (persist) { - persistPlugins() - plugin.subscribe('pe-persist', persistPlugins) + persistPlugins(); + plugin.subscribe("pe-persist", persistPlugins); } } @@ -85,11 +87,11 @@ 
export function addPlugin(plugin, persist = true) { * @returns {void} */ export function persistPlugins() { - const persistData = {} + const persistData: Record<string, Plugin> = {}; for (const name in plugins) { - persistData[name] = plugins[name] + persistData[name] = plugins[name]; } - writeFileSync(getPluginsFile(), JSON.stringify(persistData), 'utf8') + writeFileSync(getPluginsFile(), JSON.stringify(persistData), "utf8"); } /** @@ -99,26 +101,26 @@ export function persistPlugins() { * @returns {Promise.<Array.<Plugin>>} New plugin * @alias pluginManager.installPlugins */ -export async function installPlugins(plugins, store = true) { - const installed = [] +export async function installPlugins(plugins: any, store = true) { + const installed: Plugin[] = []; for (const plg of plugins) { // Set install options and activation based on input type - const isObject = typeof plg === 'object' - const spec = isObject ? [plg.specifier, plg] : [plg] - const activate = isObject ? plg.activate !== false : true + const isObject = typeof plg === "object"; + const spec = isObject ? [plg.specifier, plg] : [plg]; + const activate = isObject ? plg.activate !== false : true; // Install and possibly activate plugin - const plugin = new Plugin(...spec) - await plugin._install() - if (activate) plugin.setActive(true) + const plugin = new Plugin(...spec); + await plugin._install(); + if (activate) plugin.setActive(true); // Add plugin to store if needed - if (store) addPlugin(plugin) - installed.push(plugin) + if (store) addPlugin(plugin); + installed.push(plugin); } // Return list of all installed plugins - return installed + return installed; } /** @@ -126,4 +128,4 @@ export async function installPlugins(plugins, store = true) { * options used to install the plugin with some extra options. * @param {string} specifier the NPM specifier that identifies the package. * @param {boolean} [activate] Whether this plugin should be activated after installation. Defaults to true. 
- */ \ No newline at end of file + */ diff --git a/electron/main.ts b/electron/main.ts index 3a2da0aca..b11963e67 100644 --- a/electron/main.ts +++ b/electron/main.ts @@ -9,7 +9,7 @@ import { readdirSync, writeFileSync } from "fs"; import { resolve, join, extname } from "path"; import { rmdir, unlink, createWriteStream } from "fs"; -import { init } from "./core/plugin-manager/pluginMgr"; +import { init } from "./core/plugin/index"; import { setupMenu } from "./utils/menu"; import { dispose } from "./utils/disposable"; @@ -19,7 +19,8 @@ const progress = require("request-progress"); const { autoUpdater } = require("electron-updater"); const Store = require("electron-store"); -const requiredModules: Record<string, any> = {}; +let requiredModules: Record<string, any> = {}; +const networkRequests: Record<string, any> = {}; let mainWindow: BrowserWindow | undefined = undefined; app @@ -39,32 +40,19 @@ app }); app.on("window-all-closed", () => { - dispose(requiredModules); + clearImportedModules(); app.quit(); }); app.on("quit", () => { - dispose(requiredModules); + clearImportedModules(); app.quit(); }); -ipcMain.handle("setNativeThemeLight", () => { - nativeTheme.themeSource = "light"; -}); - -ipcMain.handle("setNativeThemeDark", () => { - nativeTheme.themeSource = "dark"; -}); - -ipcMain.handle("setNativeThemeSystem", () => { - nativeTheme.themeSource = "system"; -}); - function createMainWindow() { mainWindow = new BrowserWindow({ width: 1200, height: 800, - frame: false, show: false, trafficLightPosition: { x: 16, @@ -130,13 +118,39 @@ function handleAppUpdates() { }); autoUpdater.autoDownload = false; autoUpdater.autoInstallOnAppQuit = true; - autoUpdater.checkForUpdates(); + if (process.env.CI !== "e2e") { + autoUpdater.checkForUpdates(); + } } /** * Handles various IPC messages from the renderer process. */ function handleIPCs() { + /** + * Handles the "setNativeThemeLight" IPC message by setting the native theme source to "light". 
+ * This will change the appearance of the app to the light theme. + */ + ipcMain.handle("setNativeThemeLight", () => { + nativeTheme.themeSource = "light"; + }); + + /** + * Handles the "setNativeThemeDark" IPC message by setting the native theme source to "dark". + * This will change the appearance of the app to the dark theme. + */ + ipcMain.handle("setNativeThemeDark", () => { + nativeTheme.themeSource = "dark"; + }); + + /** + * Handles the "setNativeThemeSystem" IPC message by setting the native theme source to "system". + * This will change the appearance of the app to match the system's current theme. + */ + ipcMain.handle("setNativeThemeSystem", () => { + nativeTheme.themeSource = "system"; + }); + /** * Invokes a function from a plugin module in main node process. * @param _event - The IPC event object. @@ -193,6 +207,15 @@ function handleIPCs() { return join(app.getPath("userData"), "plugins"); }); + /** + * Retrieves the path to the app data directory using the `coreAPI` object. + * If the `coreAPI` object is not available, the function returns `undefined`. + * @returns A Promise that resolves with the path to the app data directory, or `undefined` if the `coreAPI` object is not available. + */ + ipcMain.handle("appDataPath", async (_event) => { + return app.getPath("userData"); + }); + /** * Returns the version of the app. * @param _event - The IPC event object. @@ -202,6 +225,15 @@ function handleIPCs() { return app.getVersion(); }); + /** + * Handles the "openAppDirectory" IPC message by opening the app's user data directory. + * The `shell.openPath` method is used to open the directory in the user's default file explorer. + * @param _event - The IPC event object. + */ + ipcMain.handle("openAppDirectory", async (_event) => { + shell.openPath(app.getPath("userData")); + }); + /** * Opens a URL in the user's default browser. * @param _event - The IPC event object. @@ -217,7 +249,7 @@ function handleIPCs() { * @param url - The URL to reload. 
*/ ipcMain.handle("relaunch", async (_event, url) => { - dispose(requiredModules); + clearImportedModules(); if (app.isPackaged) { app.relaunch(); @@ -246,7 +278,7 @@ function handleIPCs() { rmdir(fullPath, { recursive: true }, function (err) { if (err) console.log(err); - dispose(requiredModules); + clearImportedModules(); // just relaunch if packaged, should launch manually in development mode if (app.isPackaged) { @@ -300,8 +332,9 @@ function handleIPCs() { ipcMain.handle("downloadFile", async (_event, url, fileName) => { const userDataPath = app.getPath("userData"); const destination = resolve(userDataPath, fileName); + const rq = request(url); - progress(request(url), {}) + progress(rq, {}) .on("progress", function (state: any) { mainWindow?.webContents.send("FILE_DOWNLOAD_UPDATE", { ...state, @@ -313,13 +346,54 @@ function handleIPCs() { fileName, err, }); + networkRequests[fileName] = undefined; }) .on("end", function () { - mainWindow?.webContents.send("FILE_DOWNLOAD_COMPLETE", { - fileName, - }); + if (networkRequests[fileName]) { + mainWindow?.webContents.send("FILE_DOWNLOAD_COMPLETE", { + fileName, + }); + networkRequests[fileName] = undefined; + } else { + mainWindow?.webContents.send("FILE_DOWNLOAD_ERROR", { + fileName, + err: "Download cancelled", + }); + } }) .pipe(createWriteStream(destination)); + + networkRequests[fileName] = rq; + }); + + /** + * Handles the "pauseDownload" IPC message by pausing the download associated with the provided fileName. + * @param _event - The IPC event object. + * @param fileName - The name of the file being downloaded. + */ + ipcMain.handle("pauseDownload", async (_event, fileName) => { + networkRequests[fileName]?.pause(); + }); + + /** + * Handles the "resumeDownload" IPC message by resuming the download associated with the provided fileName. + * @param _event - The IPC event object. + * @param fileName - The name of the file being downloaded. 
+ */ + ipcMain.handle("resumeDownload", async (_event, fileName) => { + networkRequests[fileName]?.resume(); + }); + + /** + * Handles the "abortDownload" IPC message by aborting the download associated with the provided fileName. + * The network request associated with the fileName is then removed from the networkRequests object. + * @param _event - The IPC event object. + * @param fileName - The name of the file being downloaded. + */ + ipcMain.handle("abortDownload", async (_event, fileName) => { + const rq = networkRequests[fileName]; + networkRequests[fileName] = undefined; + rq?.abort(); }); /** @@ -385,3 +459,8 @@ function setupPlugins() { pluginsPath: join(app.getPath("userData"), "plugins"), }); } + +function clearImportedModules() { + dispose(requiredModules); + requiredModules = {}; +} diff --git a/electron/package.json b/electron/package.json index 452c69ab8..da9b34e38 100644 --- a/electron/package.json +++ b/electron/package.json @@ -53,6 +53,10 @@ "test:e2e": "playwright test --workers=1", "dev": "tsc -p . && electron .", "build": "tsc -p . && electron-builder -p never -m", + "build:test": "tsc -p . && electron-builder --dir -p never -m", + "build:test-darwin": "tsc -p . && electron-builder -p never -m --x64 --arm64 --dir", + "build:test-win32": "tsc -p . && electron-builder -p never -w --dir", + "build:test-linux": "tsc -p . && electron-builder -p never -l --dir", "build:darwin": "tsc -p . && electron-builder -p never -m --x64 --arm64", "build:win32": "tsc -p . && electron-builder -p never -w", "build:linux": "tsc -p . 
&& electron-builder -p never --linux deb", @@ -74,6 +78,8 @@ "devDependencies": { "@electron/notarize": "^2.1.0", "@playwright/test": "^1.38.1", + "@types/npmcli__arborist": "^5.6.4", + "@types/pacote": "^11.1.7", "@typescript-eslint/eslint-plugin": "^6.7.3", "@typescript-eslint/parser": "^6.7.3", "electron": "26.2.1", diff --git a/electron/preload.ts b/electron/preload.ts index dac5aef6f..e8fd723a5 100644 --- a/electron/preload.ts +++ b/electron/preload.ts @@ -1,6 +1,5 @@ // Make Pluggable Electron's facade available to the renderer on window.plugins -//@ts-ignore -const useFacade = require("../core/plugin-manager/facade"); +import { useFacade } from "./core/plugin/facade"; useFacade(); //@ts-ignore const { contextBridge, ipcRenderer } = require("electron"); @@ -19,6 +18,8 @@ contextBridge.exposeInMainWorld("electronAPI", { pluginPath: () => ipcRenderer.invoke("pluginPath"), + appDataPath: () => ipcRenderer.invoke("appDataPath"), + reloadPlugins: () => ipcRenderer.invoke("reloadPlugins"), appVersion: () => ipcRenderer.invoke("appVersion"), @@ -27,6 +28,8 @@ contextBridge.exposeInMainWorld("electronAPI", { relaunch: () => ipcRenderer.invoke("relaunch"), + openAppDirectory: () => ipcRenderer.invoke("openAppDirectory"), + deleteFile: (filePath: string) => ipcRenderer.invoke("deleteFile", filePath), installRemotePlugin: (pluginName: string) => @@ -35,6 +38,15 @@ contextBridge.exposeInMainWorld("electronAPI", { downloadFile: (url: string, path: string) => ipcRenderer.invoke("downloadFile", url, path), + pauseDownload: (fileName: string) => + ipcRenderer.invoke("pauseDownload", fileName), + + resumeDownload: (fileName: string) => + ipcRenderer.invoke("resumeDownload", fileName), + + abortDownload: (fileName: string) => + ipcRenderer.invoke("abortDownload", fileName), + onFileDownloadUpdate: (callback: any) => ipcRenderer.on("FILE_DOWNLOAD_UPDATE", callback), diff --git a/electron/tests/navigation.e2e.spec.ts b/electron/tests/navigation.e2e.spec.ts index 
90220d5db..104333650 100644 --- a/electron/tests/navigation.e2e.spec.ts +++ b/electron/tests/navigation.e2e.spec.ts @@ -40,8 +40,12 @@ test("renders left navigation panel", async () => { expect(chatSection).toBe(false); // Home actions - const botBtn = await page.getByTestId("Bot").first().isEnabled(); + /* Disable unstable feature tests + ** const botBtn = await page.getByTestId("Bot").first().isEnabled(); + ** Enable back when it is whitelisted + */ + const myModelsBtn = await page.getByTestId("My Models").first().isEnabled(); const settingsBtn = await page.getByTestId("Settings").first().isEnabled(); - expect([botBtn, myModelsBtn, settingsBtn].filter((e) => !e).length).toBe(0); + expect([myModelsBtn, settingsBtn].filter((e) => !e).length).toBe(0); }); diff --git a/package.json b/package.json index 53c046c72..52ec0191e 100644 --- a/package.json +++ b/package.json @@ -4,13 +4,16 @@ "workspaces": { "packages": [ "electron", - "web" + "web", + "server" ], "nohoist": [ "electron", "electron/**", "web", - "web/**" + "web/**", + "server", + "server/**" ] }, "scripts": { @@ -19,20 +22,29 @@ "dev:electron": "yarn workspace jan dev", "dev:web": "yarn workspace jan-web dev", "dev": "concurrently --kill-others \"yarn dev:web\" \"wait-on http://localhost:3000 && yarn dev:electron\"", - "test-local": "yarn lint && yarn build && yarn test", + "test-local": "yarn lint && yarn build:test && yarn test", "build:core": "cd core && yarn install && yarn run build", "build:web": "yarn workspace jan-web build && cpx \"web/out/**\" \"electron/renderer/\"", "build:electron": "yarn workspace jan build", - "build:plugins": "rimraf ./electron/core/pre-install/*.tgz && concurrently --kill-others-on-fail \"cd ./plugins/data-plugin && npm install && npm run postinstall\" \"cd ./plugins/inference-plugin && npm install && npm run postinstall\" \"cd ./plugins/model-management-plugin && npm install && npm run postinstall\" \"cd ./plugins/monitoring-plugin && npm install && npm run postinstall\" && 
concurrently --kill-others-on-fail \"cd ./plugins/data-plugin && npm run build:publish\" \"cd ./plugins/inference-plugin && npm run build:publish\" \"cd ./plugins/model-management-plugin && npm run build:publish\" \"cd ./plugins/monitoring-plugin && npm run build:publish\"", - "build:plugins-darwin": "rimraf ./electron/core/pre-install/*.tgz && concurrently \"cd ./plugins/data-plugin && npm install && npm run build:deps && npm run postinstall\" \"cd ./plugins/inference-plugin && npm install && npm run postinstall\" \"cd ./plugins/model-management-plugin && npm install && npm run postinstall\" \"cd ./plugins/monitoring-plugin && npm install && npm run postinstall\" && chmod +x ./.github/scripts/auto-sign.sh && ./.github/scripts/auto-sign.sh && concurrently \"cd ./plugins/data-plugin && npm run build:publish\" \"cd ./plugins/inference-plugin && npm run build:publish\" \"cd ./plugins/model-management-plugin && npm run build:publish\" \"cd ./plugins/monitoring-plugin && npm run build:publish\"", + "build:electron:test": "yarn workspace jan build:test", + "build:pull-plugins": "rimraf ./electron/core/pre-install/*.tgz && cd ./electron/core/pre-install && npm pack @janhq/inference-plugin @janhq/data-plugin @janhq/model-management-plugin @janhq/monitoring-plugin", + "build:plugins": "rimraf ./electron/core/pre-install/*.tgz && concurrently --kill-others-on-fail \"cd ./plugins/data-plugin && npm install && npm run postinstall\" \"cd ./plugins/inference-plugin && npm install --ignore-scripts && npm run postinstall:dev\" \"cd ./plugins/model-management-plugin && npm install && npm run postinstall\" \"cd ./plugins/monitoring-plugin && npm install && npm run postinstall\" && concurrently --kill-others-on-fail \"cd ./plugins/data-plugin && npm run build:publish\" \"cd ./plugins/inference-plugin && npm run build:publish\" \"cd ./plugins/model-management-plugin && npm run build:publish\" \"cd ./plugins/monitoring-plugin && npm run build:publish\"", + "build:plugins-web": "rimraf 
./electron/core/pre-install/*.tgz && concurrently --kill-others-on-fail \"cd ./plugins/data-plugin && npm install && npm run build:deps && npm run postinstall\" \"cd ./plugins/inference-plugin && npm install && npm run postinstall\" \"cd ./plugins/model-management-plugin && npm install && npm run postinstall\" \"cd ./plugins/monitoring-plugin && npm install && npm run postinstall\" && concurrently --kill-others-on-fail \"cd ./plugins/data-plugin && npm run build:publish\" \"cd ./plugins/inference-plugin && npm run build:publish\" \"cd ./plugins/model-management-plugin && npm run build:publish\" \"cd ./plugins/monitoring-plugin && npm run build:publish\"", "build": "yarn build:web && yarn build:electron", + "build:test": "yarn build:web && yarn build:electron:test", + "build:test-darwin": "yarn build:web && yarn workspace jan build:test-darwin", + "build:test-win32": "yarn build:web && yarn workspace jan build:test-win32", + "build:test-linux": "yarn build:web && yarn workspace jan build:test-linux", "build:darwin": "yarn build:web && yarn workspace jan build:darwin", "build:win32": "yarn build:web && yarn workspace jan build:win32", "build:linux": "yarn build:web && yarn workspace jan build:linux", "build:publish": "yarn build:web && yarn workspace jan build:publish", "build:publish-darwin": "yarn build:web && yarn workspace jan build:publish-darwin", "build:publish-win32": "yarn build:web && yarn workspace jan build:publish-win32", - "build:publish-linux": "yarn build:web && yarn workspace jan build:publish-linux" + "build:publish-linux": "yarn build:web && yarn workspace jan build:publish-linux", + "build:web-plugins": "yarn build:web && yarn build:plugins-web && mkdir -p \"./web/out/plugins/data-plugin\" && cp \"./plugins/data-plugin/dist/esm/index.js\" \"./web/out/plugins/data-plugin\" && mkdir -p \"./web/out/plugins/inference-plugin\" && cp \"./plugins/inference-plugin/dist/index.js\" \"./web/out/plugins/inference-plugin\" && mkdir -p 
\"./web/out/plugins/model-management-plugin\" && cp \"./plugins/model-management-plugin/dist/index.js\" \"./web/out/plugins/model-management-plugin\" && mkdir -p \"./web/out/plugins/monitoring-plugin\" && cp \"./plugins/monitoring-plugin/dist/index.js\" \"./web/out/plugins/monitoring-plugin\"", + "server:prod": "yarn workspace server build && yarn build:web-plugins && cpx \"web/out/**\" \"server/build/renderer/\" && mkdir -p ./server/build/@janhq && cp -r ./plugins/* ./server/build/@janhq", + "start:server": "yarn server:prod && node server/build/main.js" }, "devDependencies": { "concurrently": "^8.2.1", diff --git a/plugins/data-plugin/README.md b/plugins/data-plugin/README.md index 2197d9ad2..802883fab 100644 --- a/plugins/data-plugin/README.md +++ b/plugins/data-plugin/README.md @@ -1,8 +1,6 @@ -## Database handler plugin for Jan App - -**Notice**: please only install dependencies and run build using npm and not yarn. +## Jan data handler plugin - index.ts: Main entry point for the plugin. - module.ts: Defines the plugin module which would be executed by the main node process. -- package.json: Defines the plugin metadata. -- tsconfig.json: Defines the typescript configuration. +- package.json: Plugin & npm module manifest. 
+ diff --git a/plugins/data-plugin/module.ts b/plugins/data-plugin/module.ts index b5497e978..21878d0f7 100644 --- a/plugins/data-plugin/module.ts +++ b/plugins/data-plugin/module.ts @@ -16,7 +16,7 @@ const dbs: Record = {}; */ function createCollection(name: string, schema?: { [key: string]: any }): Promise { return new Promise((resolve) => { - const dbPath = path.join(app.getPath("userData"), "databases"); + const dbPath = path.join(appPath(), "databases"); if (!fs.existsSync(dbPath)) fs.mkdirSync(dbPath); const db = new PouchDB(`${path.join(dbPath, name)}`); dbs[name] = db; @@ -226,6 +226,13 @@ function findMany( .then((data) => data.docs); // Return documents } +function appPath() { + if (app) { + return app.getPath("userData"); + } + return process.env.APPDATA || (process.platform == 'darwin' ? process.env.HOME + '/Library/Preferences' : process.env.HOME + "/.local/share"); +} + module.exports = { createCollection, deleteCollection, diff --git a/plugins/data-plugin/package.json b/plugins/data-plugin/package.json index a685039b1..dbdd31bbc 100644 --- a/plugins/data-plugin/package.json +++ b/plugins/data-plugin/package.json @@ -1,12 +1,14 @@ { "name": "@janhq/data-plugin", - "version": "1.0.16", + "version": "1.0.19", "description": "The Data Connector provides easy access to a data API using the PouchDB engine. 
It offers accessible data management capabilities.", "icon": "https://raw.githubusercontent.com/tailwindlabs/heroicons/88e98b0c2b458553fbadccddc2d2f878edc0387b/src/20/solid/circle-stack.svg", "main": "dist/esm/index.js", "module": "dist/cjs/module.js", "author": "Jan ", - "license": "MIT", + "license": "AGPL-3.0", + "supportCloudNative": true, + "url": "/plugins/data-plugin/index.js", "activationPoints": [ "init" ], diff --git a/plugins/inference-plugin/README.md b/plugins/inference-plugin/README.md index ae70eb4ec..455783efb 100644 --- a/plugins/inference-plugin/README.md +++ b/plugins/inference-plugin/README.md @@ -1,3 +1,7 @@ +# Jan inference plugin + +Created using Jan app example + # Create a Jan Plugin using Typescript Use this template to bootstrap the creation of a TypeScript Jan plugin. 🚀 @@ -71,3 +75,4 @@ There are a few things to keep in mind when writing your plugin code: [documentation](https://github.com/janhq/jan/blob/main/core/README.md). So, what are you waiting for? Go ahead and start customizing your plugin! + diff --git a/plugins/inference-plugin/module.ts b/plugins/inference-plugin/module.ts index b851f0c4e..ed3be094d 100644 --- a/plugins/inference-plugin/module.ts +++ b/plugins/inference-plugin/module.ts @@ -3,7 +3,7 @@ const { app } = require("electron"); const { spawn } = require("child_process"); const fs = require("fs"); const tcpPortUsed = require("tcp-port-used"); -const { killPortProcess } = require("kill-port-process"); +const kill = require("kill-port"); const PORT = 3928; let subprocess = null; @@ -23,25 +23,27 @@ const initModel = (fileName) => { let binaryFolder = path.join(__dirname, "nitro"); // Current directory by default let binaryName; - if (process.platform === "win32") { - // Todo: Need to check for CUDA support to switch between CUDA and non-CUDA binaries - binaryName = "nitro_start_windows.bat"; - } else if (process.platform === "darwin") { - // Mac OS platform - binaryName = - process.arch === "arm64" - ? 
"nitro_mac_arm64" - : "nitro_mac_intel"; - } else { - // Linux - // Todo: Need to check for CUDA support to switch between CUDA and non-CUDA binaries - binaryName = "nitro_start_linux.sh"; // For other platforms - } + if (process.platform === "win32") { + // Todo: Need to check for CUDA support to switch between CUDA and non-CUDA binaries + binaryName = "win-start.bat"; + } else if (process.platform === "darwin") { + // Mac OS platform + if (process.arch === "arm64") { + binaryFolder = path.join(binaryFolder, "mac-arm64") + } else { + binaryFolder = path.join(binaryFolder, "mac-x64") + } + binaryName = "nitro" + } else { + // Linux + // Todo: Need to check for CUDA support to switch between CUDA and non-CUDA binaries + binaryName = "linux-start.sh"; // For other platforms + } const binaryPath = path.join(binaryFolder, binaryName); - // Execute the binary - subprocess = spawn(binaryPath, { cwd: binaryFolder }); + // Execute the binary + subprocess = spawn(binaryPath,["0.0.0.0", PORT], { cwd: binaryFolder }); // Handle subprocess output subprocess.stdout.on("data", (data) => { @@ -61,7 +63,7 @@ const initModel = (fileName) => { }) .then(() => tcpPortUsed.waitUntilUsed(PORT, 300, 30000)) .then(() => { - const llama_model_path = path.join(app.getPath("userData"), fileName); + const llama_model_path = path.join(appPath(), fileName); const config = { llama_model_path, @@ -102,11 +104,18 @@ function killSubprocess() { subprocess = null; console.log("Subprocess terminated."); } else { - killPortProcess(PORT); + kill(PORT, "tcp").then(console.log).catch(console.log); console.error("No subprocess is currently running."); } } +function appPath() { + if (app) { + return app.getPath("userData"); + } + return process.env.APPDATA || (process.platform == 'darwin' ? 
process.env.HOME + '/Library/Preferences' : process.env.HOME + "/.local/share"); +} + module.exports = { initModel, killSubprocess, diff --git a/plugins/inference-plugin/nitro/ggml-metal.metal b/plugins/inference-plugin/nitro/ggml-metal.metal deleted file mode 100644 index 99b9fd7a7..000000000 --- a/plugins/inference-plugin/nitro/ggml-metal.metal +++ /dev/null @@ -1,2526 +0,0 @@ -#include - -using namespace metal; - -#define MAX(x, y) ((x) > (y) ? (x) : (y)) - -#define QK4_0 32 -#define QR4_0 2 -typedef struct { - half d; // delta - uint8_t qs[QK4_0 / 2]; // nibbles / quants -} block_q4_0; - -#define QK4_1 32 -typedef struct { - half d; // delta - half m; // min - uint8_t qs[QK4_1 / 2]; // nibbles / quants -} block_q4_1; - -#define QK8_0 32 -typedef struct { - half d; // delta - int8_t qs[QK8_0]; // quants -} block_q8_0; - -// general-purpose kernel for addition of two tensors -// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 -// cons: not very efficient -kernel void kernel_add( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant int64_t & nb00, - constant int64_t & nb01, - constant int64_t & nb02, - constant int64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, - constant int64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant int64_t & nb0, - constant int64_t & nb1, - constant int64_t & nb2, - constant int64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 
= i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0]; - - src0_ptr += ntg.x*nb00; - src1_ptr += ntg.x*nb10; - dst_ptr += ntg.x*nb0; - } -} - -// assumption: src1 is a row -// broadcast src1 into src0 -kernel void kernel_add_row( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - constant int64_t & nb [[buffer(27)]], - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] + src1[tpig % nb]; -} - -kernel void kernel_mul( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src1[tpig]; -} - -// assumption: src1 is a row -// broadcast src1 into src0 -kernel void kernel_mul_row( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - constant int64_t & nb, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src1[tpig % nb]; -} - -kernel void kernel_scale( - device const float4 * src0, - device float4 * dst, - constant float & scale, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale; -} - -kernel void kernel_silu( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - dst[tpig] = x / (1.0f + exp(-x)); -} - -kernel void kernel_relu( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = max(0.0f, src0[tpig]); -} - -kernel void kernel_sqr( - device const float * src0, - device float * dst, - uint 
tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src0[tpig]; -} - -constant float GELU_COEF_A = 0.044715f; -constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - -kernel void kernel_gelu( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - // BEWARE !!! - // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! - // This was observed with Falcon 7B and 40B models - // - dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_soft_max( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - // parallel max - float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY; - for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) { - lmax = MAX(lmax, psrc0[i00]); - } - const float max = simd_max(lmax); - - // parallel sum - float lsum = 0.0f; - for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { - const float exp_psrc0 = exp(psrc0[i00] - max); - lsum += exp_psrc0; - // Remember the result of exp here. exp is expensive, so we really do not - // whish to compute it twice. 
- pdst[i00] = exp_psrc0; - } - - const float sum = simd_sum(lsum); - - for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { - pdst[i00] /= sum; - } -} - -kernel void kernel_soft_max_4( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - - // parallel max - float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY; - for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) { - lmax4 = fmax(lmax4, psrc4[i00]); - } - float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); - - const float max = simd_max(lmax); - - // parallel sum - float4 lsum4 = 0.0f; - for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) { - const float4 exp_psrc4 = exp(psrc4[i00] - max); - lsum4 += exp_psrc4; - pdst4[i00] = exp_psrc4; - } - float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; - - const float sum = simd_sum(lsum); - - for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) { - pdst4[i00] /= sum; - } -} - -kernel void kernel_diag_mask_inf( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int & n_past, - uint3 tpig[[thread_position_in_grid]]) { - const int64_t i02 = tpig[2]; - const int64_t i01 = tpig[1]; - const int64_t i00 = tpig[0]; - - if (i00 > n_past + i01) { - dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; - } else { - dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; - } -} - -kernel void kernel_diag_mask_inf_8( - device const float4 * src0, 
- device float4 * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int & n_past, - uint3 tpig[[thread_position_in_grid]]) { - - const int64_t i = 2*tpig[0]; - - dst[i+0] = src0[i+0]; - dst[i+1] = src0[i+1]; - int64_t i4 = 4*i; - const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; - const int64_t i01 = i4/(ne00); i4 -= i01*ne00; - const int64_t i00 = i4; - for (int k = 3; k >= 0; --k) { - if (i00 + 4 + k <= n_past + i01) { - break; - } - dst[i+1][k] = -INFINITY; - if (i00 + k > n_past + i01) { - dst[i][k] = -INFINITY; - } - } -} - -kernel void kernel_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * sum [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); - // MEAN - // parallel sum - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - sum[tpitg] += x[i00]; - } - // reduce - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - const float mean = sum[0] / ne00; - - // recenter and VARIANCE - threadgroup_barrier(mem_flags::mem_threadgroup); - device float * y = dst + tgpig*ne00; - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - y[i00] = x[i00] - mean; - sum[tpitg] += y[i00] * y[i00]; - } - - // reduce - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - const float variance = sum[0] / ne00; - - const float scale = 1.0f/sqrt(variance + eps); - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - y[i00] = y[i00] * 
scale; - } -} - -kernel void kernel_rms_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * sum [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); - device const float * x_scalar = (device const float *) x; - - float4 sumf = 0; - float all_sum = 0; - - // parallel sum - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - sumf += x[i00] * x[i00]; - } - all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; - all_sum = simd_sum(all_sum); - if (tiisg == 0) { - sum[sgitg] = all_sum; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // broadcast, simd group number is ntg / 32 - for (uint i = ntg / 32 / 2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - } - if (tpitg == 0) { - for (int i = 4 * (ne00 / 4); i < ne00; i++) { - sum[0] += x_scalar[i]; - } - sum[0] /= ne00; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - const float mean = sum[0]; - const float scale = 1.0f/sqrt(mean + eps); - - device float4 * y = (device float4 *) (dst + tgpig*ne00); - device float * y_scalar = (device float *) y; - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - y[i00] = x[i00] * scale; - } - if (tpitg == 0) { - for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) { - y_scalar[i00] = x_scalar[i00] * scale; - } - } -} - -// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q4 quants begin (0 or QK4_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float 
block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - float2 acc = 0.f; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); - } - return d * (sumy * -8.f + acc[0] + acc[1]); -} - -// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q4 quants begin (0 or QK4_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - float m = qb_curr->m; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); - float2 acc = 0.f; - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); - } - return d * (acc[0] + acc[1]) + sumy * m; -} - -// putting them in the kernel cause a significant performance penalty -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 -//Note: This is a template, but strictly speaking it only applies to -// quantizations where the block size is 32. It also does not -// giard against the number of rows not being divisible by -// N_DST, so this is another explicit assumption of the implementation. 
-template -void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, - int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa, - uint3 tgpig, uint tiisg, uint sgitg) { - const int nb = ne00/QK4_0; - - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * nsg + sgitg) * nr; - - const uint offset0 = first_row * nb + im/gqa*(nb*ne0); - - device const block_q_type * x = (device const block_q_type *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[16]; // src1 vector cache - float sumf[nr] = {0.f}; - - const int ix = (tiisg/2); - const int il = (tiisg%2)*8; - - device const float * yb = y + ix * QK4_0 + il; - - // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += nw/2) { - float sumy = 0; - for (int i = 0; i < 8; i += 2) { - sumy += yb[i] + yb[i+1]; - yl[i+0] = yb[i+ 0]; - yl[i+1] = yb[i+ 1]/256.f; - - sumy += yb[i+16] + yb[i+17]; - yl[i+8] = yb[i+16]/16.f; - yl[i+9] = yb[i+17]/4096.f; - } - - for (int row = 0; row < nr; row++) { - sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); - } - - yb += QK4_0 * 16; - } - - for (int row = 0; row < nr; ++row) { - const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot; - } - } -} - -kernel void kernel_mul_mv_q4_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint 
sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); -} - -kernel void kernel_mul_mv_q4_1_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); -} - -#define NB_Q8_0 8 - -kernel void kernel_mul_mv_q8_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int nr = N_DST; - const int nsg = N_SIMDGROUP; - const int nw = N_SIMDWIDTH; - - const int nb = ne00/QK8_0; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr; - const uint offset0 = first_row * nb + im/gqa*(nb*ne0); - device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[NB_Q8_0]; - float sumf[nr]={0.f}; - - const int ix = tiisg/4; - const int il = tiisg%4; - - device const float * yb = y + ix * QK8_0 + NB_Q8_0*il; - - // 
each thread in a SIMD group deals with NB_Q8_0 quants at a time - for (int ib = ix; ib < nb; ib += nw/4) { - for (int i = 0; i < NB_Q8_0; ++i) { - yl[i] = yb[i]; - } - - for (int row = 0; row < nr; row++) { - device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; - float sumq = 0.f; - for (int iq = 0; iq < NB_Q8_0; ++iq) { - sumq += qs[iq] * yl[iq]; - } - sumf[row] += sumq*x[ib+row*nb].d; - } - - yb += NB_Q8_0 * nw; - } - - for (int row = 0; row < nr; ++row) { - const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; - } - } -} - -#define N_F32_F32 4 - -kernel void kernel_mul_mv_f32_f32( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F32_F32; - const int64_t im = tgpig.z; - - device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); - - if (ne00 < 128) { - for (int row = 0; row < N_F32_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - device const float4 * x4 = (device const float4 *)x; - for (int row = 0; row < N_F32_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - 
break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - device const float4 * y4 = (device const float4 *) y; - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -kernel void kernel_mul_mv_f16_f32_1row( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; - - device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - if (ne00 < 128) { - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } else { - device const half4 * x4 = (device const half4 *) x; - device const float4 * y4 = (device const float4 *) y; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k]; - } - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - -} - -#define N_F16_F32 4 - 
-kernel void kernel_mul_mv_f16_f32( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F16_F32; - const int64_t im = tgpig.z; - - device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - device const half4 * x4 = (device const half4 *)x; - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - device const float4 * y4 = (device const float4 *) y; - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -// Assumes row size (ne00) is a multiple of 4 -kernel void kernel_mul_mv_f16_f32_l4( - device const char * src0, - device const char * src1, - device float * dst, - constant 
int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int nrows = ne11; - const int64_t r0 = tgpig.x; - const int64_t im = tgpig.z; - - device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); - - for (int r1 = 0; r1 < nrows; ++r1) { - device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } -} - -kernel void kernel_alibi_f32( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant float & m0, - constant float & m1, - constant int & n_heads_log2_floor, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / 
(ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - float m_k; - if (i2 < n_heads_log2_floor) { - m_k = pow(m0, i2 + 1); - } else { - m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1); - } - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1); - } -} - -typedef void (rope_t)( - device const void * src0, - device const int32_t * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & mode, - constant float & freq_base, - constant float & freq_scale, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]); - -template -kernel void kernel_rope( - device const void * src0, - device const int32_t * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - 
constant int & n_dims, - constant int & mode, - constant float & freq_base, - constant float & freq_scale, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]) { - const int64_t i3 = tgpig[2]; - const int64_t i2 = tgpig[1]; - const int64_t i1 = tgpig[0]; - - const bool is_neox = mode & 2; - - device const int32_t * pos = src1; - - const int64_t p = pos[i2]; - - const float theta_0 = freq_scale * (float)p; - const float inv_ndims = -1.f/n_dims; - - if (!is_neox) { - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { - - const float theta = theta_0 * pow(freq_base, inv_ndims*i0); - const float cos_theta = cos(theta); - const float sin_theta = sin(theta); - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const T x0 = src[0]; - const T x1 = src[1]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[1] = x0*sin_theta + x1*cos_theta; - } - } else { - for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { - for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) { - - const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib); - const float cos_theta = cos(theta); - const float sin_theta = sin(theta); - - const int64_t i0 = ib*n_dims + ic/2; - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - } - } - } -} - -template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; -template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; - -kernel void kernel_cpy_f16_f16( - device const half * src0, - device half * 
dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f16( - device const float * src0, - device half * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const 
int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f32( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} - -kernel 
void kernel_concat( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i02 < ne02) { - ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0]; - src0_ptr += ntg.x*nb00; - } else { - ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0]; - src1_ptr += ntg.x*nb10; - } - dst_ptr += ntg.x*nb0; - } -} - -//============================================ k-quants ====================================================== - -#ifndef QK_K -#define QK_K 256 -#else -static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64"); -#endif - -#if QK_K == 256 -#define K_SCALE_SIZE 12 -#else -#define K_SCALE_SIZE 4 -#endif - -typedef struct { - uint8_t 
scales[QK_K/16]; // scales and mins, quantized with 4 bits - uint8_t qs[QK_K/4]; // quants - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins -} block_q2_K; -// 84 bytes / block - -typedef struct { - uint8_t hmask[QK_K/8]; // quants - high bit - uint8_t qs[QK_K/4]; // quants - low 2 bits -#if QK_K == 64 - uint8_t scales[2]; -#else - uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits -#endif - half d; // super-block scale -} block_q3_K; - -#if QK_K == 64 -typedef struct { - half d[2]; // super-block scales/mins - uint8_t scales[2]; - uint8_t qs[QK_K/2]; // 4-bit quants -} block_q4_K; -#else -typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_K; -#endif - -#if QK_K == 64 -typedef struct { - half d; // super-block scales/mins - int8_t scales[QK_K/16]; // 8-bit block scales - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -#else -typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -// 176 bytes / block -#endif - -typedef struct { - uint8_t ql[QK_K/2]; // quants, lower 4 bits - uint8_t qh[QK_K/4]; // quants, upper 2 bits - int8_t scales[QK_K/16]; // scales, quantized with 8 bits - half d; // super-block scale -} block_q6_K; -// 210 bytes / block - -static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { - uchar4 r; - if (j < 4) { - r[0] = q[j+0] & 63; - r[2] = q[j+1] & 63; - r[1] = q[j+4] & 63; - r[3] = q[j+5] & 63; - } else { - r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); - r[2] = (q[j+5] & 0xF) | ((q[j-3] 
>> 6) << 4); - r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); - r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4); - } - return r; -} - -//====================================== dot products ========================= - -kernel void kernel_mul_mv_q2_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int r2 = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - const uint offset0 = r2/gqa*(nb*ne0); - device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_q2_K) * nb; - -#if QK_K == 256 - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int im = it/4; // 0 or 1 - const int ir = it%4; // 0...3 - const int is = (8*ir)/16;// 0 or 1 - - device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir; - - for (int ib = ix; ib < nb; ib += 4) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; - yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; - yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; - } - - device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir; 
- device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); - acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); - acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); - acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); - acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); - acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); - acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); - acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); - } - float dall = dh[0]; - float dmin = dh[1] * 1.f/16.f; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + - (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f + - (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f + - (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - - dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); - - qs += step/2; - sc += step; - dh += step/2; - } - - y4 += 4 * QK_K; - } -#else - const int ix = tiisg/2; // 0...15 - const int it = tiisg%2; // 0...1 - - device const float * y4 = y + ix * QK_K + 8 * it; - - for (int ib = ix; ib < nb; ib += 16) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8]; - yl[i+16] = y4[i+32]; sumy[2] += yl[i+16]; - yl[i+24] = y4[i+48]; sumy[3] += yl[i+24]; - } - - device const uint8_t * sc = (device const uint8_t *)x[ib].scales; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); - acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); - acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); - acc2[1] += yl[i+ 9] * 
                // (tail of the preceding quantized mat-vec kernel; its opening lines are
                // outside this chunk) — finish the 2-bit product accumulation for this row
                (qs[i/2] & 0x0c00);
                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
            }

            float dall = dh[0];
            float dmin = dh[1];
            // combine partial dot products: 1.f/256.f folds the odd (high-byte) lanes back,
            // the 1/4^k factors undo each 2-bit field's bit position; dmin applies the
            // per-sub-block minimum via the precomputed sumy[] row sums
            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
                         dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));

            // advance to the same position in the next row (step counts uint16/uint8 units)
            qs += step/2;
            sc += step;
            dh += step/2;
        }

        y4 += 16 * QK_K;
    }
#endif

    // reduce each row's partial sum across the simdgroup; lane 0 writes the result
    for (int row = 0; row < N_DST; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
        }
    }
}

#if QK_K == 256
// Mat-vec product of a q3_K-quantized matrix (src0) with an f32 vector (src1),
// QK_K == 256 super-block layout. Each simdgroup computes 2 output rows; the 3-bit
// quants are split into 2 low bits (qs) plus 1 high bit (hmask), with 6-bit scales.
kernel void kernel_mul_mv_q3_K_f32(
        device const void * src0,
        device const float * src1,
        device float * dst,
        constant int64_t & ne00,
        constant int64_t & ne01[[buffer(4)]],
        constant int64_t & ne02[[buffer(5)]],
        constant int64_t & ne10[[buffer(9)]],
        constant int64_t & ne12[[buffer(11)]],
        constant int64_t & ne0[[buffer(15)]],
        constant int64_t & ne1[[buffer(16)]],
        constant uint & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {

    const int nb = ne00/QK_K;  // number of super-blocks per row

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
    const int64_t r2 = tgpig.z;

    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
    const uint offset0 = r2/gqa*(nb*ne0);  // grouped-query-attention head sharing offset
    device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
    device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;

    float yl[32];

    //const uint16_t kmask1 = 0x3030;
    //const uint16_t kmask2 = 0x0f0f;

    // each simd lane handles an (ip, il, ir) slice of every 4th super-block
    const int tid = tiisg/4;
    const int ix = tiisg%4;
    const int ip = tid/4;          // 0 or 1
    const int il = 2*((tid%4)/2);  // 0 or 2
    const int ir = tid%2;
    const int n = 8;
    const int l0 = n*ir;

    // One would think that the Metal compiler would figure out that ip and il can only have
    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
    // with these two tales.
    //
    // Possible masks for the high bit
    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2

    // Possible masks for the low 2 bits
    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};

    const ushort4 hm = mm[2*ip + il/2];

    const int shift = 2*il;
    const float v1 = il == 0 ? 4.f : 64.f;
    const float v2 = 4.f * v1;

    const uint16_t s_shift1 = 4*ip;
    const uint16_t s_shift2 = s_shift1 + il;

    const int q_offset = 32*ip + l0;
    const int y_offset = 128*ip + 32*il + l0;

    const int step = sizeof(block_q3_K) * nb / 2;  // row stride in uint16 units

    device const float * y1 = yy + ix*QK_K + y_offset;

    // scales32 is repacked in-place each iteration; scales16/scales alias its bytes
    uint32_t scales32, aux32;
    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
    thread const int8_t * scales = (thread const int8_t *)&scales32;

    float sumf1[2] = {0.f};
    float sumf2[2] = {0.f};
    for (int i = ix; i < nb; i += 4) {

        // stage this lane's 32 activations from the current super-block
        for (int l = 0; l < 8; ++l) {
            yl[l+ 0] = y1[l+ 0];
            yl[l+ 8] = y1[l+16];
            yl[l+16] = y1[l+32];
            yl[l+24] = y1[l+48];
        }

        device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
        device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
        device const uint16_t * a = (device const uint16_t *)(x[i].scales);
        device const half * dh = &x[i].d;

        for (int row = 0; row < 2; ++row) {

            const float d_all = (float)dh[0];

            // unpack the four 6-bit scales for this slice into scales32's bytes:
            // high 2 bits come from a[4..5], low 4 bits from a[il..il+1]
            scales16[0] = a[4];
            scales16[1] = a[5];
            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
            scales16[0] = a[il+0];
            scales16[1] = a[il+1];
            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;

            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
            for (int l = 0; l < n; l += 2) {
                const int32_t qs = q[l/2];
                s1 += yl[l+0] * (qs & qm[il/2][0]);
                s2 += yl[l+1] * (qs & qm[il/2][1]);
                // s3/s6 collect activations whose high bit is NOT set (sign correction)
                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
                s4 += yl[l+16] * (qs & qm[il/2][2]);
                s5 += yl[l+17] * (qs & qm[il/2][3]);
                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
            }
            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
            sumf1[row] += d1 * (scales[0] - 32);
            sumf2[row] += d2 * (scales[2] - 32);

            // second half of this lane's quants (offset +8 uint16s)
            s1 = s2 = s3 = s4 = s5 = s6 = 0;
            for (int l = 0; l < n; l += 2) {
                const int32_t qs = q[l/2+8];
                s1 += yl[l+8] * (qs & qm[il/2][0]);
                s2 += yl[l+9] * (qs & qm[il/2][1]);
                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
                s4 += yl[l+24] * (qs & qm[il/2][2]);
                s5 += yl[l+25] * (qs & qm[il/2][3]);
                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
            }
            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
            sumf1[row] += d1 * (scales[1] - 32);
            sumf2[row] += d2 * (scales[3] - 32);

            q += step;
            h += step;
            a += step;
            dh += step;

        }

        y1 += 4 * QK_K;

    }

    // undo the bit-position scaling (1 << shift) and reduce across the simdgroup
    for (int row = 0; row < 2; ++row) {
        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
        sumf1[row] = simd_sum(sumf);
    }
    if (tiisg == 0) {
        for (int row = 0; row < 2; ++row) {
            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
        }
    }
}
#else
// Small-super-block (QK_K != 256) variant of the q3_K mat-vec kernel:
// one output row per simdgroup, 8 super-blocks strided per lane group.
kernel void kernel_mul_mv_q3_K_f32(
        device const void * src0,
        device const float * src1,
        device float * dst,
        constant int64_t & ne00,
        constant int64_t & ne01[[buffer(4)]],
        constant int64_t & ne02[[buffer(5)]],
        constant int64_t & ne10[[buffer(9)]],
        constant int64_t & ne12[[buffer(11)]],
        constant int64_t & ne0[[buffer(15)]],
        constant int64_t & ne1[[buffer(16)]],
        constant uint & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {

    const int nb = ne00/QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
    const int64_t r2 = tgpig.z;

    const int row = 2 * r0 + sgitg;
    const uint offset0 = r2/gqa*(nb*ne0);
    device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
    device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
    const int ix = tiisg/4;
    const int il = 4 * (tiisg%4);// 0, 4, 8, 12
    const int im = il/8;         // 0, 0, 1, 1
    const int in = il%8;         // 0, 4, 0, 4

    float2 sum = {0.f, 0.f};

    for (int i = ix; i < nb; i += 8) {

        const float d_all = (float)(x[i].d);

        device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
        device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
        device const uint16_t * s = (device const uint16_t *)(x[i].scales);
        device const float * y = yy + i * QK_K +
il; // (continues the pointer setup split at the chunk boundary above)

        // per-field scales: each 4-bit scale is pre-divided by its bit position
        // (16, 4096, 262144) so the raw masked quants can be used directly, minus 8 bias
        const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
        const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
        const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
        const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;

        for (int l = 0; l < 4; l += 2) {
            const uint16_t hm = h[l/2] >> im;
            // sum[0] = even byte lanes, sum[1] = odd byte lanes (folded back by 1/256 below);
            // each hm bit selects whether the high (sign) bit correction applies
            sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
                    + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
                    + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
                    + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
            sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
                    + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
                    + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
                    + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
        }

    }
    const float sumf = sum[0] + sum[1] * 1.f/256.f;

    const float tot = simd_sum(sumf);
    if (tiisg == 0) {
        dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
    }

}
#endif

#if QK_K == 256
// Mat-vec product of a q4_K-quantized matrix (src0) with an f32 vector (src1),
// QK_K == 256 layout: 4-bit quants with packed 6-bit scales/mins (kmask unpacking).
// Each simdgroup accumulates N_DST output rows.
kernel void kernel_mul_mv_q4_K_f32(
        device const void * src0,
        device const float * src1,
        device float * dst,
        constant int64_t & ne00,
        constant int64_t & ne01 [[buffer(4)]],
        constant int64_t & ne02 [[buffer(5)]],
        constant int64_t & ne10 [[buffer(9)]],
        constant int64_t & ne12 [[buffer(11)]],
        constant int64_t & ne0 [[buffer(15)]],
        constant int64_t & ne1 [[buffer(16)]],
        constant uint & gqa [[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {

    // bit masks for unpacking the 6-bit scale/min pairs from x[ib].scales
    const uint16_t kmask1 = 0x3f3f;
    const uint16_t kmask2 = 0x0f0f;
    const uint16_t kmask3 = 0xc0c0;

    const int ix = tiisg/8;  // 0...3
    const int it = tiisg%8;  // 0...7
    const int im = it/4;     // 0 or 1
    const int ir = it%4;     // 0...3

    const int nb = ne00/QK_K;
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int r2 = tgpig.z;
    //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
    const int first_row = r0 * N_DST;
    const int ib_row = first_row * nb;
    const uint offset0 = r2/gqa*(nb*ne0);
    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
    device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
    float yl[16];
    float yh[16];
    float sumf[N_DST]={0.f}, all_sum;

    const int step = sizeof(block_q4_K) * nb / 2;  // row stride in uint16 units

    device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;

    uint16_t sc16[4];
    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;  // byte view of unpacked scales

    for (int ib = ix; ib < nb; ib += 4) {

        // stage activations and their per-sub-block sums (used for the min term)
        float4 sumy = {0.f, 0.f, 0.f, 0.f};
        for (int i = 0; i < 8; ++i) {
            yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
            yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
            yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
            yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
        }

        device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
        device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
        device const half * dh = &x[ib].d;

        for (int row = 0; row < N_DST; row++) {

            // unpack 6-bit scales (sc8[0,1,4,5]) and mins (sc8[2,3,6,7])
            sc16[0] = sc[0] & kmask1;
            sc16[1] = sc[2] & kmask1;
            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);

            device const uint16_t * q2 = q1 + 32;

            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
            for (int i = 0; i < 8; i += 2) {
                acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
                acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
                acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
                acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
                acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
                acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
                acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
                acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
            }

            float dall = dh[0];
            float dmin = dh[1];
            // 1/256 folds odd-byte lanes; 1/16 undoes the high-nibble position
            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
                                 (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
                                 (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);

            q1 += step;
            sc += step;
            dh += step;
        }

        y4 += 4 * QK_K;
    }

    for (int row = 0; row < N_DST; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
        }
    }
}
#else
// Small-super-block (QK_K != 256) variant of the q4_K mat-vec kernel.
kernel void kernel_mul_mv_q4_K_f32(
        device const void * src0,
        device const float * src1,
        device float * dst,
        constant int64_t & ne00,
        constant int64_t & ne01[[buffer(4)]],
        constant int64_t & ne02[[buffer(5)]],
        constant int64_t & ne10[[buffer(9)]],
        constant int64_t & ne12[[buffer(11)]],
        constant int64_t & ne0[[buffer(15)]],
        constant int64_t & ne1[[buffer(16)]],
        constant uint & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {

    const int ix = tiisg/4;  // 0...7
    const int it = tiisg%4;  // 0...3

    const int nb = ne00/QK_K;
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int r2 = tgpig.z;
    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
    const int ib_row = first_row * nb;
    const uint offset0 = r2/gqa*(nb*ne0);
    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
    device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
    float yl[8];
    float yh[8];
    float sumf[N_DST]={0.f}, all_sum;

    const int step = sizeof(block_q4_K) * nb / 2;

    device const float * y4 = y + ix * QK_K + 8 * it;

    uint16_t sc16[4];

    for (int ib = ix; ib < nb; ib += 8) {

        float2 sumy = {0.f, 0.f};
        for (int i = 0; i < 8; ++i) {
            yl[i] = y4[i+ 0]; sumy[0] += yl[i];
            yh[i] = y4[i+32]; sumy[1] += yh[i];
        }

        device const uint16_t * sc = (device const
uint16_t *)x[ib].scales; // (continues the declaration split at the chunk boundary above)
        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
        // NOTE(review): the 256-block variant takes &x[ib].d here; this form compiles only
        // if d is an array/half2 in the small-QK_K block_q4_K — confirm against the struct.
        device const half * dh = x[ib].d;

        for (int row = 0; row < N_DST; row++) {

            // 4-bit scales (sc16[0,1]) and mins (sc16[2,3]) kept at their bit positions;
            // the 1/4096, 1/16, 1/256 factors below undo those positions
            sc16[0] = sc[0] & 0x000f;
            sc16[1] = sc[0] & 0x0f00;
            sc16[2] = sc[0] & 0x00f0;
            sc16[3] = sc[0] & 0xf000;

            float2 acc1 = {0.f, 0.f};
            float2 acc2 = {0.f, 0.f};
            for (int i = 0; i < 8; i += 2) {
                acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
                acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
                acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
                acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
            }

            float dall = dh[0];
            float dmin = dh[1];
            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
                         dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);

            qs += step;
            sc += step;
            dh += step;
        }

        y4 += 8 * QK_K;
    }

    for (int row = 0; row < N_DST; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
        }
    }
}
#endif

// Mat-vec product of a q5_K-quantized matrix (src0) with an f32 vector (src1).
// 5-bit quants: 4 low bits in qs plus 1 high bit in qh. Two output rows per simdgroup;
// both QK_K layouts are handled by the #if below.
kernel void kernel_mul_mv_q5_K_f32(
        device const void * src0,
        device const float * src1,
        device float * dst,
        constant int64_t & ne00,
        constant int64_t & ne01[[buffer(4)]],
        constant int64_t & ne02[[buffer(5)]],
        constant int64_t & ne10[[buffer(9)]],
        constant int64_t & ne12[[buffer(11)]],
        constant int64_t & ne0[[buffer(15)]],
        constant int64_t & ne1[[buffer(16)]],
        constant uint & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {

    const int nb = ne00/QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
    const int r2 = tgpig.z;

    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
    const uint offset0 = r2/gqa*(nb*ne0);
    device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
    device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;

    float sumf[2]={0.f};

    const int step = sizeof(block_q5_K) * nb;  // row stride in bytes

#if QK_K == 256
#
    // NOTE(review): the lone '#' above is a preprocessor null directive — harmless but
    // likely unintentional; confirm against upstream before removing.
    float yl[16], yh[16];

    // masks for unpacking the packed 6-bit scale/min pairs
    const uint16_t kmask1 = 0x3f3f;
    const uint16_t kmask2 = 0x0f0f;
    const uint16_t kmask3 = 0xc0c0;

    const int tid = tiisg/4;
    const int ix = tiisg%4;
    const int im = tid/4;
    const int ir = tid%4;
    const int n = 8;

    const int l0 = n*ir;
    const int q_offset = 32*im + l0;
    const int y_offset = 64*im + l0;

    // one high-bit mask per quadrant of this lane's slice
    const uint8_t hm1 = 1u << (2*im);
    const uint8_t hm2 = hm1 << 1;
    const uint8_t hm3 = hm1 << 4;
    const uint8_t hm4 = hm2 << 4;

    uint16_t sc16[4];
    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;

    device const float * y1 = yy + ix*QK_K + y_offset;

    for (int i = ix; i < nb; i += 4) {

        device const uint8_t * q1 = x[i].qs + q_offset;
        device const uint8_t * qh = x[i].qh + l0;
        device const half * dh = &x[i].d;
        device const uint16_t * a = (device const uint16_t *)x[i].scales + im;

        device const float * y2 = y1 + 128;
        float4 sumy = {0.f, 0.f, 0.f, 0.f};
        for (int l = 0; l < 8; ++l) {
            yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
            yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
            yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
            yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
        }

        for (int row = 0; row < 2; ++row) {

            device const uint8_t * q2 = q1 + 64;

            sc16[0] = a[0] & kmask1;
            sc16[1] = a[2] & kmask1;
            sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
            sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);

            // acc1 = low-nibble dot products; acc2 = activations whose 5th bit is set
            float4 acc1 = {0.f};
            float4 acc2 = {0.f};
            for (int l = 0; l < n; ++l) {
                uint8_t h = qh[l];
                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
                acc1[2] += yh[l+0] * (q2[l] & 0x0F);
                acc1[3] += yh[l+8] * (q2[l] & 0xF0);
                acc2[0] += h & hm1 ? yl[l+0] : 0.f;
                acc2[1] += h & hm2 ? yl[l+8] : 0.f;
                acc2[2] += h & hm3 ? yh[l+0] : 0.f;
                acc2[3] += h & hm4 ? yh[l+8] : 0.f;
            }
            const float dall = dh[0];
            const float dmin = dh[1];
            // /16 undoes the high-nibble position; 16*acc2 adds the 5th bit's weight
            sumf[row] += dall * (sc8[0] * (acc1[0]      + 16.f*acc2[0]) +
                                 sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
                                 sc8[4] * (acc1[2]      + 16.f*acc2[2]) +
                                 sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);

            q1 += step;
            qh += step;
            dh += step/2;
            a += step/2;

        }

        y1 += 4 * QK_K;

    }
#else
    float yl[8], yh[8];

    const int il = 4 * (tiisg/8);  // 0, 4, 8, 12
    const int ix = tiisg%8;
    const int im = il/8;  // 0, 0, 1, 1
    const int in = il%8;  // 0, 4, 0, 4

    device const float * y = yy + ix*QK_K + il;

    for (int i = ix; i < nb; i += 8) {

        for (int l = 0; l < 4; ++l) {
            yl[l+0] = y[l+ 0];
            yl[l+4] = y[l+16];
            yh[l+0] = y[l+32];
            yh[l+4] = y[l+48];
        }

        device const half * dh = &x[i].d;
        device const uint8_t * q = x[i].qs + il;
        device const uint8_t * h = x[i].qh + in;
        device const int8_t * s = x[i].scales;

        for (int row = 0; row < 2; ++row) {

            const float d = dh[0];

            float2 acc = {0.f, 0.f};
            for (int l = 0; l < 4; ++l) {
                const uint8_t hl = h[l] >> im;
                acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
                        + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
                acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
                        + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ?
0 : 256)); // (continues the expression split at the chunk boundary above)
            }
            // 1/16 folds the high-nibble partial sum back to its true magnitude
            sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);

            q += step;
            h += step;
            s += step;
            dh += step/2;

        }

        y += 8 * QK_K;
    }
#endif

    // simdgroup reduction; lane 0 writes the two finished rows
    for (int row = 0; row < 2; ++row) {
        const float tot = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
        }
    }

}

// Mat-vec product of a q6_K-quantized matrix (src0) with an f32 vector (src1).
// 6-bit quants: 4 low bits in ql, 2 high bits in qh (selected via the kmask
// constants), int8 scales, value bias of 32. One output row per simdgroup.
kernel void kernel_mul_mv_q6_K_f32(
        device const void * src0,
        device const float * src1,
        device float * dst,
        constant int64_t & ne00,
        constant int64_t & ne01[[buffer(4)]],
        constant int64_t & ne02[[buffer(5)]],
        constant int64_t & ne10[[buffer(9)]],
        constant int64_t & ne12[[buffer(11)]],
        constant int64_t & ne0[[buffer(15)]],
        constant int64_t & ne1[[buffer(16)]],
        constant uint & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {

    // masks selecting the 2 high bits for each of the four 16-value strips
    const uint8_t kmask1 = 0x03;
    const uint8_t kmask2 = 0x0C;
    const uint8_t kmask3 = 0x30;
    const uint8_t kmask4 = 0xC0;

    const int nb = ne00/QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
    const int r2 = tgpig.z;

    const int row = 2 * r0 + sgitg;
    const uint offset0 = r2/gqa*(nb*ne0);
    device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
    device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;

    float sumf = 0;

#if QK_K == 256
    const int tid = tiisg/2;
    const int ix = tiisg%2;
    const int ip = tid/8;  // 0 or 1
    const int il = tid%8;
    const int n = 4;
    const int l0 = n*il;
    const int is = 8*ip + l0/16;

    const int y_offset = 128*ip + l0;
    const int q_offset_l = 64*ip + l0;
    const int q_offset_h = 32*ip + l0;

    for (int i = ix; i < nb; i += 2) {

        device const uint8_t * q1 = x[i].ql + q_offset_l;
        device const uint8_t * q2 = q1 + 32;
        device const uint8_t * qh = x[i].qh + q_offset_h;
        device const int8_t * sc = x[i].scales + is;

        device const float * y = yy + i * QK_K + y_offset;

        const float dall = x[i].d;

        // reassemble each 6-bit value (low nibble | shifted high bits), bias by 32
        float4 sums = {0.f, 0.f, 0.f, 0.f};
        for (int l = 0; l < n; ++l) {
            sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
            sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
            sums[2] += y[l+64] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
            sums[3] += y[l+96] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
        }

        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);

    }

#else
    const int ix = tiisg/4;
    const int il = 4*(tiisg%4);

    for (int i = ix; i < nb; i += 8) {
        device const float * y = yy + i * QK_K + il;
        device const uint8_t * ql = x[i].ql + il;
        device const uint8_t * qh = x[i].qh + il;
        device const int8_t * s = x[i].scales;

        const float d = x[i].d;

        float4 sums = {0.f, 0.f, 0.f, 0.f};
        for (int l = 0; l < 4; ++l) {
            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
            sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
            sums[2] += y[l+32] * ((int8_t)((ql[l+ 0]  >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
            sums[3] += y[l+48] * ((int8_t)((ql[l+16]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
        }
        sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
    }

#endif

    const float tot = simd_sum(sumf);
    if (tiisg == 0) {
        dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
    }
}

//============================= templates and their specializations =============================

// NOTE(review): throughout this section the angle-bracket template parameter
// lists (e.g. "template <typename type4x4>") appear to have been stripped by the
// capture that produced this text — restore them from upstream before compiling.

// NOTE: this is not dequantizing - we are simply fitting the template
template
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
    // straight 4x4 copy; il is unused for f32
    float4x4 temp = *(((device float4x4 *)src));
    for (int i = 0; i < 16; i++){
        reg[i/4][i%4] = temp[i/4][i%4];
    }
}

// Copy one half4x4 tile into reg (element-wise, converting via the target type).
template
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
    half4x4 temp = *(((device half4x4 *)src));
    for (int i = 0; i < 16;
i++){ // (continues the loop header split at the chunk boundary above)
        reg[i/4][i%4] = temp[i/4][i%4];
    }
}

// Dequantize one 16-value tile of a q4_0 block: 4-bit quants, scale d, implicit -8 bias.
// il selects low (0) or high (1) nibbles; d2/mask1 handle the odd byte lanes packed
// into each uint16 (hence the /256 scale).
template
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
    device const uint16_t * qs = ((device const uint16_t *)xb + 1);  // quants follow the half d
    const float d1 = il ? (xb->d / 16.h) : xb->d;
    const float d2 = d1 / 256.f;
    const float md = -8.h * xb->d;  // folds the -8 zero-point into an additive term
    const ushort mask0 = il ? 0x00F0 : 0x000F;
    const ushort mask1 = mask0 << 8;

    for (int i=0;i<8;i++) {
        reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
        reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
    }
}

// Dequantize one 16-value tile of a q4_1 block: like q4_0 but with an explicit
// additive minimum m instead of the fixed -8 bias.
template
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
    device const uint16_t * qs = ((device const uint16_t *)xb + 2);  // quants follow d and m
    const float d1 = il ? (xb->d / 16.h) : xb->d;
    const float d2 = d1 / 256.f;
    const float m = xb->m;
    const ushort mask0 = il ? 0x00F0 : 0x000F;
    const ushort mask1 = mask0 << 8;

    for (int i=0;i<8;i++) {
        reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
        reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
    }
}

// Dequantize one 16-value tile of a q8_0 block: signed 8-bit quants times scale d.
template
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
    device const int8_t * qs = ((device const int8_t *)xb->qs);
    const half d = xb->d;

    for (int i=0;i<16;i++) {
        reg[i/4][i%4] = (qs[i + 16*il] * d);
    }
}

// Dequantize one 16-value tile of a q2_K block: 2-bit quants with per-sub-block
// 4-bit scale (low nibble) and min (high nibble) taken from scales[il].
template
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
    const half d = xb->d;
    const half min = xb->dmin;
    device const uint8_t * q = (device const uint8_t *)xb->qs;
    half dl, ml;
    uint8_t sc = xb->scales[il];

#if QK_K == 256
    q = q + 32*(il/8) + 16*(il&1);
    il = (il/2)%4;
#endif
    // coef/mask pick which 2-bit field of each byte this tile decodes
    half  coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
    uchar mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
    dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
    }
}

// Dequantize one 16-value tile of a q3_K block: 2 low bits from qs plus a high bit
// from hmask; 6-bit signed scales are reassembled from two packed scale bytes.
template
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
    const half d_all = xb->d;
    device const uint8_t * q = (device const uint8_t *)xb->qs;
    device const uint8_t * h = (device const uint8_t *)xb->hmask;
    device const int8_t * scales = (device const int8_t *)xb->scales;

#if QK_K == 256
    q = q + 32 * (il/8) + 16 * (il&1);
    h = h + 16 * (il&1);
    uint8_t m = 1 << (il/2);
    uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
                                 ((il/4)>0 ? 12 : 3);
    uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
    uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
    // 6-bit scale = 4 low bits (scale_2) | 2 high bits (scale_1), bias 32
    int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
                              : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
    half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
    const half ml = 4.h * dl;  // correction applied when the high bit is clear

    il = (il/2) & 3;
    const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
    const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
    dl *= coef;

    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
    }
#else
    float kcoef = il&1 ? 1.f/16.f : 1.f;
    uint16_t kmask = il&1 ? 0xF0 : 0x0F;
    float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
    float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
    uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
    uint8_t m = 1<<(il*2);
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
    }
#endif
}

// Extract the (scale, min) 6-bit pair for sub-block j (+k) from the packed
// 12-byte q4_K/q5_K scales layout: first 4 pairs are stored whole, the rest are
// split across low nibbles and the top bits of the first 8 bytes.
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
    return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
}

// Dequantize one 16-value tile of a q4_K block: 4-bit quants with 6-bit
// scale/min pairs (256 layout) or two half scales + nibble scales (small layout).
template
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
    device const uchar * q = xb->qs;

#if QK_K == 256
    short is = (il/4) * 2;
    q = q + (il/4) * 32 + 16 * (il&1);
    il = il & 3;
    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
    const half d = il < 2 ? xb->d : xb->d / 16.h;  // /16 pre-compensates high-nibble reads
    const half min = xb->dmin;
    const half dl = d * sc[0];
    const half ml = min * sc[1];
#else
    q = q + 16 * (il&1);
    device const uint8_t * s = xb->scales;
    device const half2 * dh = (device const half2 *)xb->d;
    const float2 d = (float2)dh[0];
    const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
    const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
#endif
    const ushort mask = il<2 ? 0x0F : 0xF0;
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
    }
}

// Dequantize one 16-value tile of a q5_K block: q4_K plus a 5th bit per value
// taken from qh (ul selects the bit; qh_val is its weight at this nibble position).
template
void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
    device const uint8_t * q = xb->qs;
    device const uint8_t * qh = xb->qh;

#if QK_K == 256
    short is = (il/4) * 2;
    q = q + 32 * (il/4) + 16 * (il&1);
    qh = qh + 16 * (il&1);
    uint8_t ul = 1 << (il/2);
    il = il & 3;
    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
    const half d = il < 2 ? xb->d : xb->d / 16.h;
    const half min = xb->dmin;
    const half dl = d * sc[0];
    const half ml = min * sc[1];

    const ushort mask = il<2 ? 0x0F : 0xF0;
    const half qh_val = il<2 ? 16.h : 256.h;
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
    }
#else
    q = q + 16 * (il&1);
    device const int8_t * s = xb->scales;
    const float dl = xb->d * s[il];
    uint8_t m = 1<<(il*2);
    const float coef = il<2 ? 1.f : 1.f/16.f;
    const ushort mask = il<2 ?
0x0F : 0xF0; // (continues the declaration split at the chunk boundary above)
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
    }
#endif
}

// Dequantize one 16-value tile of a q6_K block: 4 low bits from ql plus 2 high
// bits from qh (kmask1 selects them, shifted into place), int8 scale, bias 32
// (folded into ml = d * sc * 32).
template
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
    const half d_all = xb->d;
    device const uint8_t * ql = (device const uint8_t *)xb->ql;
    device const uint8_t * qh = (device const uint8_t *)xb->qh;
    device const int8_t * scales = (device const int8_t *)xb->scales;

#if QK_K == 256
    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
    qh = qh + 32*(il/8) + 16*(il&1);
    half sc = scales[(il%2) + 2 * ((il/2))];
    il = (il/2) & 3;
#else
    ql = ql + 16 * (il&1);
    half sc = scales[il];
#endif
    const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
    const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
    const half coef = il>1 ? 1.f/16.h : 1.h;
    const half ml = d_all * sc * 32.h;
    const half dl = d_all * sc * coef;
    for (int i = 0; i < 16; ++i) {
        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
        reg[i/4][i%4] = dl * q - ml;
    }
}

// Gather rows of a quantized src0 selected by the int32 indices in src1,
// dequantizing each 16-value tile into f32 dst. One threadgroup per output row;
// threads stride over the row's tiles.
template
kernel void kernel_get_rows(
        device const void * src0,
        device const int * src1,
        device float * dst,
        constant int64_t & ne00,
        constant uint64_t & nb01,
        constant uint64_t & nb1,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint tptg[[threads_per_threadgroup]]) {
    const int i = tgpig;
    const int r = ((device int32_t *) src1)[i];  // source row index for this output row

    for (int ind = tiitg; ind < ne00/16; ind += tptg) {
        float4x4 temp;
        dequantize_func(
            ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
        *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
    }
}

// Tiling parameters for the simdgroup-matrix matmul kernel below.
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
#define BLOCK_SIZE_K 32
#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
#define THREAD_PER_BLOCK 128
#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
#define SG_MAT_ROW 8

// Tiled matrix-matrix multiply: dequantizes src0 tiles into threadgroup memory
// and multiplies with f32 src1 via simdgroup matrices.
// NOTE(review): this kernel continues beyond the end of this chunk — the body
// below is the visible portion only.
// each block_q contains 16*nl weights
template
kernel void kernel_mul_mm(device const uchar * src0,
                          device const uchar * src1,
                          device float * dst,
                          constant int64_t & ne00,
                          constant int64_t & ne02,
                          constant int64_t & nb01,
                          constant int64_t & nb02,
                          constant int64_t & ne12,
                          constant int64_t & nb10,
                          constant int64_t & nb11,
                          constant int64_t & nb12,
                          constant int64_t & ne0,
                          constant int64_t & ne1,
                          constant uint & gqa,
                          threadgroup uchar * shared_memory [[threadgroup(0)]],
                          uint3 tgpig[[threadgroup_position_in_grid]],
                          uint tiitg[[thread_index_in_threadgroup]],
                          uint sgitg[[simdgroup_index_in_threadgroup]]) {

    // shared memory split: first 4096 bytes hold the half A tile, rest the float B tile
    threadgroup half * sa = (threadgroup half *)(shared_memory);
    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);

    const uint r0 = tgpig.y;
    const uint r1 = tgpig.x;
    const uint im = tgpig.z;

    // if this block is of 64x32 shape or smaller
    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;

    // a thread shouldn't load data outside of the matrix
    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;

    simdgroup_half8x8 ma[4];
    simdgroup_float8x8 mb[2];
    simdgroup_float8x8 c_res[8];
    for (int i = 0; i < 8; i++){
        c_res[i] = make_filled_simdgroup_matrix(0.f);
    }

    short il = (tiitg % THREAD_PER_ROW);

    uint offset0 = im/gqa*nb02;
    ushort offset1 = il/nl;

    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
    device const float * y = (device const float *)(src1
        + nb12 * im
        + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));

    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
        // load data and store to threadgroup memory
        half4x4 temp_a;
        dequantize_func(x, il, temp_a);
        threadgroup_barrier(mem_flags::mem_threadgroup);

        #pragma unroll(16)
        for (int i = 0; i < 16; i++) {
            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
            + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
            + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
        }

        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);

        il = (il + 2 < nl) ? il + 2 : il % 2;
        x = (il < 2) ?
x + (2+nl-1)/nl : x; - y += BLOCK_SIZE_K; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // load matrices from threadgroup memory and conduct outer products - threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); - threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); - - #pragma unroll(4) - for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { - #pragma unroll(4) - for (int i = 0; i < 4; i++) { - simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); - } - simdgroup_barrier(mem_flags::mem_none); - #pragma unroll(2) - for (int i = 0; i < 2; i++) { - simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); - } - - lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; - lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; - - #pragma unroll(8) - for (int i = 0; i < 8; i++){ - simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); - } - } - } - - if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { - device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ - + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); - } - } else { - // block is smaller than 64x32, we should avoid writing data outside of the matrix - threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ - + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; - if (sgitg == 0) { - for (int i = 0; i < n_rows; i++) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); - } - } - } - } -} - -#if QK_K == 256 -#define QK_NL 16 -#else -#define QK_NL 4 -#endif 
- -typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \ - constant uint64_t &, constant uint64_t &, uint, uint, uint); - -template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; - -typedef void (mat_mm_t)( - device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant int64_t & nb01, - constant int64_t & nb02, - constant int64_t & ne12, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & gqa, - threadgroup uchar *, uint3, uint, uint); - -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t 
kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; diff --git a/plugins/inference-plugin/nitro/libcrypto-3-x64.dll b/plugins/inference-plugin/nitro/libcrypto-3-x64.dll deleted file mode 100644 index b9223ec18..000000000 Binary files a/plugins/inference-plugin/nitro/libcrypto-3-x64.dll and /dev/null differ diff --git a/plugins/inference-plugin/nitro/libssl-3-x64.dll b/plugins/inference-plugin/nitro/libssl-3-x64.dll deleted file mode 100644 index d48518e4f..000000000 Binary files a/plugins/inference-plugin/nitro/libssl-3-x64.dll and /dev/null differ diff --git a/plugins/inference-plugin/nitro/linux-cpu/.gitkeep b/plugins/inference-plugin/nitro/linux-cpu/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/plugins/inference-plugin/nitro/linux-cuda/.gitkeep b/plugins/inference-plugin/nitro/linux-cuda/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/plugins/inference-plugin/nitro/linux-start.sh b/plugins/inference-plugin/nitro/linux-start.sh new file mode 100644 index 000000000..647d3b254 --- /dev/null +++ b/plugins/inference-plugin/nitro/linux-start.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +# Attempt to run the nitro_linux_amd64_cuda file and if it fails, run nitro_linux_amd64 +cd linux-cuda +./nitro "$@" || (echo "nitro_linux_amd64_cuda encountered an error, attempting to run nitro_linux_amd64..." 
&& cd ../linux-cpu && ./nitro "$@") diff --git a/plugins/inference-plugin/nitro/llama.dll b/plugins/inference-plugin/nitro/llama.dll deleted file mode 100644 index 73a1dd038..000000000 Binary files a/plugins/inference-plugin/nitro/llama.dll and /dev/null differ diff --git a/plugins/inference-plugin/nitro/mac-arm64/.gitkeep b/plugins/inference-plugin/nitro/mac-arm64/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/plugins/inference-plugin/nitro/mac-x64/.gitkeep b/plugins/inference-plugin/nitro/mac-x64/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/plugins/inference-plugin/nitro/nitro_linux_amd64 b/plugins/inference-plugin/nitro/nitro_linux_amd64 deleted file mode 100755 index dc3e2c73b..000000000 Binary files a/plugins/inference-plugin/nitro/nitro_linux_amd64 and /dev/null differ diff --git a/plugins/inference-plugin/nitro/nitro_linux_amd64_cuda b/plugins/inference-plugin/nitro/nitro_linux_amd64_cuda deleted file mode 100755 index b9ba0b68d..000000000 Binary files a/plugins/inference-plugin/nitro/nitro_linux_amd64_cuda and /dev/null differ diff --git a/plugins/inference-plugin/nitro/nitro_mac_arm64 b/plugins/inference-plugin/nitro/nitro_mac_arm64 deleted file mode 100755 index ffe5beb8c..000000000 Binary files a/plugins/inference-plugin/nitro/nitro_mac_arm64 and /dev/null differ diff --git a/plugins/inference-plugin/nitro/nitro_mac_intel b/plugins/inference-plugin/nitro/nitro_mac_intel deleted file mode 100755 index c74b1aec3..000000000 Binary files a/plugins/inference-plugin/nitro/nitro_mac_intel and /dev/null differ diff --git a/plugins/inference-plugin/nitro/nitro_start_linux.sh b/plugins/inference-plugin/nitro/nitro_start_linux.sh deleted file mode 100755 index 19ac36f21..000000000 --- a/plugins/inference-plugin/nitro/nitro_start_linux.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash - -#!/bin/bash - -# Attempt to run the nitro_linux_amd64_cuda file and if it fails, run nitro_linux_amd64 -./nitro_linux_amd64_cuda || (echo 
"nitro_linux_amd64_cuda encountered an error, attempting to run nitro_linux_amd64..." && ./nitro_linux_amd64) diff --git a/plugins/inference-plugin/nitro/nitro_windows_amd64.exe b/plugins/inference-plugin/nitro/nitro_windows_amd64.exe deleted file mode 100644 index 9b2131fbf..000000000 Binary files a/plugins/inference-plugin/nitro/nitro_windows_amd64.exe and /dev/null differ diff --git a/plugins/inference-plugin/nitro/nitro_windows_amd64_cuda.exe b/plugins/inference-plugin/nitro/nitro_windows_amd64_cuda.exe deleted file mode 100644 index 36c3d8c04..000000000 Binary files a/plugins/inference-plugin/nitro/nitro_windows_amd64_cuda.exe and /dev/null differ diff --git a/plugins/inference-plugin/nitro/version.txt b/plugins/inference-plugin/nitro/version.txt new file mode 100644 index 000000000..7693c96bf --- /dev/null +++ b/plugins/inference-plugin/nitro/version.txt @@ -0,0 +1 @@ +0.1.3 \ No newline at end of file diff --git a/plugins/inference-plugin/nitro/win-cpu/.gitkeep b/plugins/inference-plugin/nitro/win-cpu/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/plugins/inference-plugin/nitro/win-cuda/.gitkeep b/plugins/inference-plugin/nitro/win-cuda/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/plugins/inference-plugin/nitro/nitro_start_windows.bat b/plugins/inference-plugin/nitro/win-start.bat old mode 100755 new mode 100644 similarity index 80% rename from plugins/inference-plugin/nitro/nitro_start_windows.bat rename to plugins/inference-plugin/nitro/win-start.bat index 4f7cb4247..7d1e7dc8d --- a/plugins/inference-plugin/nitro/nitro_start_windows.bat +++ b/plugins/inference-plugin/nitro/win-start.bat @@ -1,10 +1,12 @@ @echo off rem Attempt to run nitro_windows_amd64_cuda.exe -nitro_windows_amd64_cuda.exe +cd win-cuda +nitro.exe rem Check the exit code of the previous command if %errorlevel% neq 0 ( echo nitro_windows_amd64_cuda.exe encountered an error, attempting to run nitro_windows_amd64.exe... 
- nitro_windows_amd64.exe + cd ..\win-cpu + nitro.exe ) diff --git a/plugins/inference-plugin/nitro/zlib.dll b/plugins/inference-plugin/nitro/zlib.dll deleted file mode 100644 index dfec1a0b7..000000000 Binary files a/plugins/inference-plugin/nitro/zlib.dll and /dev/null differ diff --git a/plugins/inference-plugin/package.json b/plugins/inference-plugin/package.json index 9d2fbd50b..b53858547 100644 --- a/plugins/inference-plugin/package.json +++ b/plugins/inference-plugin/package.json @@ -1,18 +1,27 @@ { "name": "@janhq/inference-plugin", - "version": "1.0.14", + "version": "1.0.20", "description": "Inference Plugin, powered by @janhq/nitro, bring a high-performance Llama model inference in pure C++.", "icon": "https://raw.githubusercontent.com/tailwindlabs/heroicons/88e98b0c2b458553fbadccddc2d2f878edc0387b/src/20/solid/command-line.svg", "main": "dist/index.js", "module": "dist/module.js", "author": "Jan ", - "license": "MIT", + "license": "AGPL-3.0", + "supportCloudNative": true, + "url": "/plugins/inference-plugin/index.js", "activationPoints": [ "init" ], "scripts": { "build": "tsc -b . 
&& webpack --config webpack.config.js", - "postinstall": "rimraf *.tgz --glob && npm run build && rimraf dist/nitro/* && cpx \"nitro/**\" \"dist/nitro\"", + "downloadnitro:linux-cpu": "NITRO_VERSION=$(cat ./nitro/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64.zip -e --strip 1 -o ./nitro/linux-cpu && chmod +x ./nitro/linux-cpu/nitro", + "downloadnitro:linux-cuda": "NITRO_VERSION=$(cat ./nitro/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-cuda.zip -e --strip 1 -o ./nitro/linux-cuda && chmod +x ./nitro/linux-cuda/nitro", + "downloadnitro:mac-arm64": "NITRO_VERSION=$(cat ./nitro/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-arm64.zip -e --strip 1 -o ./nitro/mac-arm64 && chmod +x ./nitro/mac-arm64/nitro", + "downloadnitro:mac-x64": "NITRO_VERSION=$(cat ./nitro/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-amd64.zip -e --strip 1 -o ./nitro/mac-x64 && chmod +x ./nitro/mac-x64/nitro", + "downloadnitro:win-cpu": "NITRO_VERSION=$(cat ./nitro/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-win-amd64.zip -e --strip 1 -o ./nitro/win-cpu", + "downloadnitro:win-cuda": "NITRO_VERSION=$(cat ./nitro/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-win-amd64-cuda.zip -e --strip 1 -o ./nitro/win-cuda", + "postinstall": "rimraf *.tgz --glob && npm run build && npm run downloadnitro:linux-cpu && npm run downloadnitro:linux-cuda && npm run downloadnitro:mac-arm64 && npm run downloadnitro:mac-x64 && npm run downloadnitro:win-cpu && npm run downloadnitro:win-cuda && rimraf dist/nitro/* && cpx \"nitro/**\" \"dist/nitro\"", + "postinstall:dev": "rimraf *.tgz 
--glob && npm run build && npm run downloadnitro:mac-arm64 && rimraf dist/nitro/* && cpx \"nitro/**\" \"dist/nitro\"", "build:publish": "npm pack && cpx *.tgz ../../electron/core/pre-install" }, "exports": { @@ -27,15 +36,12 @@ }, "dependencies": { "@janhq/core": "^0.1.6", - "kill-port-process": "^3.2.0", + "download-cli": "^1.1.1", + "kill-port": "^2.0.1", "rxjs": "^7.8.1", "tcp-port-used": "^1.0.2", "ts-loader": "^9.5.0" }, - "bundledDependencies": [ - "tcp-port-used", - "kill-port-process" - ], "engines": { "node": ">=18.0.0" }, @@ -43,5 +49,9 @@ "dist/*", "package.json", "README.md" + ], + "bundleDependencies": [ + "tcp-port-used", + "kill-port" ] } diff --git a/plugins/inference-plugin/webpack.config.js b/plugins/inference-plugin/webpack.config.js index 2d9e18ab2..0f5f59d7f 100644 --- a/plugins/inference-plugin/webpack.config.js +++ b/plugins/inference-plugin/webpack.config.js @@ -19,7 +19,7 @@ module.exports = { new webpack.DefinePlugin({ PLUGIN_NAME: JSON.stringify(packageJson.name), MODULE_PATH: JSON.stringify(`${packageJson.name}/${packageJson.module}`), - INFERENCE_URL: JSON.stringify("http://127.0.0.1:3928/inferences/llamacpp/chat_completion"), + INFERENCE_URL: JSON.stringify(process.env.INFERENCE_URL || "http://127.0.0.1:3928/inferences/llamacpp/chat_completion"), }), ], output: { @@ -30,5 +30,8 @@ module.exports = { resolve: { extensions: [".ts", ".js"], }, + optimization: { + minimize: false, + }, // Add loaders and other configuration as needed for your project }; diff --git a/plugins/model-management-plugin/README.md b/plugins/model-management-plugin/README.md new file mode 100644 index 000000000..516bbec8b --- /dev/null +++ b/plugins/model-management-plugin/README.md @@ -0,0 +1,78 @@ +# Jan Model Management plugin + +Created using Jan app example + +# Create a Jan Plugin using Typescript + +Use this template to bootstrap the creation of a TypeScript Jan plugin. 
🚀 + +## Create Your Own Plugin + +To create your own plugin, you can use this repository as a template! Just follow the below instructions: + +1. Click the Use this template button at the top of the repository +2. Select Create a new repository +3. Select an owner and name for your new repository +4. Click Create repository +5. Clone your new repository + +## Initial Setup + +After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your plugin. + +> [!NOTE] +> +> You'll need to have a reasonably modern version of +> [Node.js](https://nodejs.org) handy. If you are using a version manager like +> [`nodenv`](https://github.com/nodenv/nodenv) or +> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the +> root of your repository to install the version specified in +> [`package.json`](./package.json). Otherwise, 20.x or later should work! + +1. :hammer_and_wrench: Install the dependencies + + ```bash + npm install + ``` + +1. :building_construction: Package the TypeScript for distribution + + ```bash + npm run bundle + ``` + +1. :white_check_mark: Check your artifact + + There will be a tgz file in your plugin directory now + +## Update the Plugin Metadata + +The [`package.json`](package.json) file defines metadata about your plugin, such as +plugin name, main entry, description and version. + +When you copy this repository, update `package.json` with the name, description for your plugin. + +## Update the Plugin Code + +The [`src/`](./src/) directory is the heart of your plugin! This contains the +source code that will be run when your plugin extension functions are invoked. You can replace the +contents of this directory with your own code. + +There are a few things to keep in mind when writing your plugin code: + +- Most Jan Plugin Extension functions are processed asynchronously. + In `index.ts`, you will see that the extension function will return a `Promise`. 
+ + ```typescript + import { core } from "@janhq/core"; + + function onStart(): Promise { + return core.invokePluginFunc(MODULE_PATH, "run", 0); + } + ``` + + For more information about the Jan Plugin Core module, see the + [documentation](https://github.com/janhq/jan/blob/main/core/README.md). + +So, what are you waiting for? Go ahead and start customizing your plugin! + diff --git a/plugins/model-management-plugin/index.ts b/plugins/model-management-plugin/index.ts index f20814fd3..b0fee7630 100644 --- a/plugins/model-management-plugin/index.ts +++ b/plugins/model-management-plugin/index.ts @@ -5,11 +5,56 @@ import { downloadFile, deleteFile, store, + EventName, + events } from "@janhq/core"; import { parseToModel } from "./helper"; -const downloadModel = (product) => +const downloadModel = (product) => { downloadFile(product.downloadUrl, product.fileName); + checkDownloadProgress(product.fileName); +} + +async function checkDownloadProgress(fileName: string) { + if (typeof window !== "undefined" && typeof (window as any).electronAPI === "undefined") { + const intervalId = setInterval(() => { + fetchDownloadProgress(fileName, intervalId); + }, 3000); + } +} + +async function fetchDownloadProgress(fileName: string, intervalId: NodeJS.Timeout): Promise { + const response = await fetch("/api/v1/downloadProgress", { + method: 'POST', + body: JSON.stringify({ fileName: fileName }), + headers: { 'Content-Type': 'application/json', 'Authorization': '' } + }); + + if (!response.ok) { + events.emit(EventName.OnDownloadError, null); + clearInterval(intervalId); + return; + } + const json = await response.json(); + if (isEmptyObject(json)) { + if (!fileName && intervalId) { + clearInterval(intervalId); + } + return Promise.resolve(""); + } + if (json.success === true) { + events.emit(EventName.OnDownloadSuccess, json); + clearInterval(intervalId); + return Promise.resolve(""); + } else { + events.emit(EventName.OnDownloadUpdate, json); + return 
Promise.resolve(json.fileName); + } +} + +function isEmptyObject(ojb: any): boolean { + return Object.keys(ojb).length === 0; +} const deleteModel = (path) => deleteFile(path); @@ -87,6 +132,9 @@ function getModelById(modelId: string): Promise { function onStart() { store.createCollection("models", {}); + if (!(window as any)?.electronAPI) { + fetchDownloadProgress(null, null).then((fileName: string) => fileName && checkDownloadProgress(fileName)); + } } // Register all the above functions and objects with the relevant extension points diff --git a/plugins/model-management-plugin/package.json b/plugins/model-management-plugin/package.json index 21ce35f38..faa849506 100644 --- a/plugins/model-management-plugin/package.json +++ b/plugins/model-management-plugin/package.json @@ -1,12 +1,14 @@ { "name": "@janhq/model-management-plugin", - "version": "1.0.9", + "version": "1.0.13", "description": "Model Management Plugin provides model exploration and seamless downloads", "icon": "https://raw.githubusercontent.com/tailwindlabs/heroicons/88e98b0c2b458553fbadccddc2d2f878edc0387b/src/20/solid/queue-list.svg", "main": "dist/index.js", "module": "dist/module.js", "author": "Jan ", - "license": "MIT", + "license": "AGPL-3.0", + "supportCloudNative": true, + "url": "/plugins/model-management-plugin/index.js", "activationPoints": [ "init" ], diff --git a/plugins/monitoring-plugin/README.md b/plugins/monitoring-plugin/README.md new file mode 100644 index 000000000..1617b9b13 --- /dev/null +++ b/plugins/monitoring-plugin/README.md @@ -0,0 +1,78 @@ +# Jan Monitoring plugin + +Created using Jan app example + +# Create a Jan Plugin using Typescript + +Use this template to bootstrap the creation of a TypeScript Jan plugin. 🚀 + +## Create Your Own Plugin + +To create your own plugin, you can use this repository as a template! Just follow the below instructions: + +1. Click the Use this template button at the top of the repository +2. Select Create a new repository +3. 
Select an owner and name for your new repository +4. Click Create repository +5. Clone your new repository + +## Initial Setup + +After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your plugin. + +> [!NOTE] +> +> You'll need to have a reasonably modern version of +> [Node.js](https://nodejs.org) handy. If you are using a version manager like +> [`nodenv`](https://github.com/nodenv/nodenv) or +> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the +> root of your repository to install the version specified in +> [`package.json`](./package.json). Otherwise, 20.x or later should work! + +1. :hammer_and_wrench: Install the dependencies + + ```bash + npm install + ``` + +1. :building_construction: Package the TypeScript for distribution + + ```bash + npm run bundle + ``` + +1. :white_check_mark: Check your artifact + + There will be a tgz file in your plugin directory now + +## Update the Plugin Metadata + +The [`package.json`](package.json) file defines metadata about your plugin, such as +plugin name, main entry, description and version. + +When you copy this repository, update `package.json` with the name, description for your plugin. + +## Update the Plugin Code + +The [`src/`](./src/) directory is the heart of your plugin! This contains the +source code that will be run when your plugin extension functions are invoked. You can replace the +contents of this directory with your own code. + +There are a few things to keep in mind when writing your plugin code: + +- Most Jan Plugin Extension functions are processed asynchronously. + In `index.ts`, you will see that the extension function will return a `Promise`. 
+ + ```typescript + import { core } from "@janhq/core"; + + function onStart(): Promise { + return core.invokePluginFunc(MODULE_PATH, "run", 0); + } + ``` + + For more information about the Jan Plugin Core module, see the + [documentation](https://github.com/janhq/jan/blob/main/core/README.md). + +So, what are you waiting for? Go ahead and start customizing your plugin! + diff --git a/plugins/monitoring-plugin/package.json b/plugins/monitoring-plugin/package.json index 49f9e4b46..e33408d68 100644 --- a/plugins/monitoring-plugin/package.json +++ b/plugins/monitoring-plugin/package.json @@ -1,12 +1,14 @@ { "name": "@janhq/monitoring-plugin", - "version": "1.0.6", + "version": "1.0.9", "description": "Utilizing systeminformation, it provides essential System and OS information retrieval", "icon": "https://raw.githubusercontent.com/tailwindlabs/heroicons/88e98b0c2b458553fbadccddc2d2f878edc0387b/src/20/solid/cpu-chip.svg", "main": "dist/index.js", "module": "dist/module.js", "author": "Jan ", - "license": "MIT", + "license": "AGPL-3.0", + "supportCloudNative": true, + "url": "/plugins/monitoring-plugin/index.js", "activationPoints": [ "init" ], diff --git a/plugins/monitoring-plugin/webpack.config.js b/plugins/monitoring-plugin/webpack.config.js index 4583497e7..25d51ffa0 100644 --- a/plugins/monitoring-plugin/webpack.config.js +++ b/plugins/monitoring-plugin/webpack.config.js @@ -29,5 +29,8 @@ module.exports = { resolve: { extensions: [".ts", ".js"], }, + optimization: { + minimize: false, + }, // Add loaders and other configuration as needed for your project }; diff --git a/server/main.ts b/server/main.ts new file mode 100644 index 000000000..b98b4097c --- /dev/null +++ b/server/main.ts @@ -0,0 +1,179 @@ +import express, { Express, Request, Response, NextFunction } from 'express' +import cors from "cors"; +import { resolve } from "path"; +const fs = require("fs"); +const progress = require("request-progress"); +const path = require("path"); +const request = 
require("request"); + +// Create app dir +const userDataPath = appPath(); +if (!fs.existsSync(userDataPath)) fs.mkdirSync(userDataPath); + +interface ProgressState { + percent?: number; + speed?: number; + size?: { + total: number; + transferred: number; + }; + time?: { + elapsed: number; + remaining: number; + }; + success?: boolean | undefined; + fileName: string; +} + +const options: cors.CorsOptions = { origin: "*" }; +const requiredModules: Record = {}; +const port = process.env.PORT || 4000; +const dataDir = __dirname; +type DownloadProgress = Record; +const downloadProgress: DownloadProgress = {}; +const app: Express = express() +app.use(express.static(dataDir + '/renderer')) +app.use(cors(options)) +app.use(express.json()); + +/** + * Execute a plugin module function via API call + * + * @param modulePath path to module name to import + * @param method function name to execute. The methods "deleteFile" and "downloadFile" will call the server function {@link deleteFile}, {@link downloadFile} instead of the plugin function. 
+ * @param args arguments to pass to the function + * @returns Promise + * + */ +app.post('/api/v1/invokeFunction', (req: Request, res: Response, next: NextFunction): void => { + const method = req.body["method"]; + const args = req.body["args"]; + switch (method) { + case "deleteFile": + deleteFile(args).then(() => res.json(Object())).catch((err: any) => next(err)); + break; + case "downloadFile": + downloadFile(args.downloadUrl, args.fileName).then(() => res.json(Object())).catch((err: any) => next(err)); + break; + default: + const result = invokeFunction(req.body["modulePath"], method, args) + if (typeof result === "undefined") { + res.json(Object()) + } else { + result?.then((result: any) => { + res.json(result) + }).catch((err: any) => next(err)); + } + } +}); + +app.post('/api/v1/downloadProgress', (req: Request, res: Response): void => { + const fileName = req.body["fileName"]; + if (fileName && downloadProgress[fileName]) { + res.json(downloadProgress[fileName]) + return; + } else { + const obj = downloadingFile(); + if (obj) { + res.json(obj) + return; + } + } + res.json(Object()); +}); + +app.use((err: Error, req: Request, res: Response, next: NextFunction): void => { + console.error("ErrorHandler", req.url, req.body, err); + res.status(500); + res.json({ error: err?.message ?? 
"Internal Server Error" }) +}); + +app.listen(port, () => console.log(`Application is running on port ${port}`)); + + +async function invokeFunction(modulePath: string, method: string, args: any): Promise { + console.log(modulePath, method, args); + const module = require(/* webpackIgnore: true */ path.join( + dataDir, + "", + modulePath + )); + requiredModules[modulePath] = module; + if (typeof module[method] === "function") { + return module[method](...args); + } else { + return Promise.resolve(); + } +} + +function downloadModel(downloadUrl: string, fileName: string): void { + const userDataPath = appPath(); + const destination = resolve(userDataPath, fileName); + console.log("Download file", fileName, "to", destination); + progress(request(downloadUrl), {}) + .on("progress", function (state: any) { + downloadProgress[fileName] = { + ...state, + fileName, + success: undefined + }; + console.log("downloading file", fileName, (state.percent * 100).toFixed(2) + '%'); + }) + .on("error", function (err: Error) { + downloadProgress[fileName] = { + ...downloadProgress[fileName], + success: false, + fileName: fileName, + }; + }) + .on("end", function () { + downloadProgress[fileName] = { + success: true, + fileName: fileName, + }; + }) + .pipe(fs.createWriteStream(destination)); +} + +function deleteFile(filePath: string): Promise { + const userDataPath = appPath(); + const fullPath = resolve(userDataPath, filePath); + return new Promise((resolve, reject) => { + fs.unlink(fullPath, function (err: any) { + if (err && err.code === "ENOENT") { + reject(Error(`File does not exist: ${err}`)); + } else if (err) { + reject(Error(`File delete error: ${err}`)); + } else { + console.log(`Delete file ${filePath} from ${fullPath}`) + resolve(); + } + }); + }) +} + +function downloadingFile(): ProgressState | undefined { + const obj = Object.values(downloadProgress).find(obj => obj && typeof obj.success === "undefined") + return obj +} + + +async function downloadFile(downloadUrl: 
string, fileName: string): Promise { + return new Promise((resolve, reject) => { + const obj = downloadingFile(); + if (obj) { + reject(Error(obj.fileName + " is being downloaded!")) + return; + }; + (async () => { + downloadModel(downloadUrl, fileName); + })().catch(e => { + console.error("downloadModel", fileName, e); + }); + resolve(); + }); +} + +function appPath(): string { + return process.env.APPDATA || (process.platform == 'darwin' ? process.env.HOME + '/Library/Preferences' : process.env.HOME + "/.local/share") +} \ No newline at end of file diff --git a/server/nodemon.json b/server/nodemon.json new file mode 100644 index 000000000..fa415fa52 --- /dev/null +++ b/server/nodemon.json @@ -0,0 +1,5 @@ +{ + "watch": [ + "main.ts" + ] +} \ No newline at end of file diff --git a/server/package.json b/server/package.json new file mode 100644 index 000000000..895cae2b9 --- /dev/null +++ b/server/package.json @@ -0,0 +1,26 @@ +{ + "name": "server", + "version": "1.0.0", + "main": "index.js", + "license": "MIT", + "dependencies": { + "cors": "^2.8.5", + "electron": "^26.2.1", + "express": "^4.18.2", + "request": "^2.88.2", + "request-progress": "^3.0.0" + }, + "devDependencies": { + "@types/cors": "^2.8.14", + "@types/express": "^4.17.18", + "@types/node": "^20.8.2", + "nodemon": "^3.0.1", + "ts-node": "^10.9.1", + "typescript": "^5.2.2" + }, + "scripts": { + "build": "tsc --project ./", + "dev": "nodemon main.ts", + "prod": "node build/main.js" + } +} diff --git a/server/tsconfig.json b/server/tsconfig.json new file mode 100644 index 000000000..a79afcdfe --- /dev/null +++ b/server/tsconfig.json @@ -0,0 +1,19 @@ +{ + "compilerOptions": { + "target": "es5", + "module": "commonjs", + "noImplicitAny": true, + "sourceMap": true, + "strict": true, + "outDir": "./build", + "rootDir": "./", + "noEmitOnError": true, + "baseUrl": ".", + "allowJs": true, + "paths": { "*": ["node_modules/*"] }, + "typeRoots": ["node_modules/@types"], + "esModuleInterop": true + }, + "include": 
["./**/*.ts"], + "exclude": ["core", "build", "dist", "tests"] + } \ No newline at end of file diff --git a/specs/Makefile b/specs/Makefile new file mode 100644 index 000000000..4646dfdcd --- /dev/null +++ b/specs/Makefile @@ -0,0 +1,14 @@ +spec: + @echo "Initiating a Spec..." + @last_number=$$(ls $(CURDIR)/jan-[0-9][0-9][0-9]-* | sort -V | tail -n 1 | cut -d '-' -f 2); \ + last_number=$$(echo $$last_number | sed 's/^0*//'); \ + next_number=$$(printf "%03d" $$(( $$last_number + 1 ))); \ + read -p "Enter Spec title: " title; \ + title=$$(echo $$title | tr ' ' '-'); \ + cp $(CURDIR)/spec-template.md $(CURDIR)/jan-$$next_number-$$title.md; \ + date=$$(date +%Y-%m-%d); \ + usernames=$$(git config user.name); \ + sed -i '' 's/{SPEC-NUM}/'$$next_number'/g' $(CURDIR)/jan-$$next_number-$$title.md; \ + sed -i '' 's/{TITLE}/'$$title'/g' $(CURDIR)/jan-$$next_number-$$title.md; \ + sed -i '' 's/{DATE}/'$$date'/g' $(CURDIR)/jan-$$next_number-$$title.md; \ + sed -i '' 's/{USERNAMES}/'$$usernames'/g' $(CURDIR)/jan-$$next_number-$$title.md \ No newline at end of file diff --git a/adr/README.md b/specs/README.md similarity index 63% rename from adr/README.md rename to specs/README.md index ae6691594..25b342ef5 100644 --- a/adr/README.md +++ b/specs/README.md @@ -1,7 +1,8 @@ -# Architectural Decision Records (ADR) +# Jan Improvement Proposals -This is a repo of key architecture decisions for Jan. [Read more about ADRs](https://github.com/joelparkerhenderson/architecture-decision-record) +This is a repo of key architecture decisions for Jan. 
+[Read more about ADRs](https://github.com/joelparkerhenderson/architecture-decision-record) ### Get started: diff --git a/adr/adr-001-jan-deployable-cloud-native.md b/specs/adrs/adr-001-jan-deployable-cloud-native.md similarity index 100% rename from adr/adr-001-jan-deployable-cloud-native.md rename to specs/adrs/adr-001-jan-deployable-cloud-native.md diff --git a/adr/adr-002-jan-ai-apps.md b/specs/adrs/adr-002-jan-ai-apps.md similarity index 100% rename from adr/adr-002-jan-ai-apps.md rename to specs/adrs/adr-002-jan-ai-apps.md diff --git a/adr/adr-003-jan-plugins.md b/specs/adrs/adr-003-jan-plugins.md similarity index 100% rename from adr/adr-003-jan-plugins.md rename to specs/adrs/adr-003-jan-plugins.md diff --git a/adr/adr-004-UI-Service.md b/specs/adrs/adr-004-UI-Service.md similarity index 100% rename from adr/adr-004-UI-Service.md rename to specs/adrs/adr-004-UI-Service.md diff --git a/adr/adr-005-model-installation.md b/specs/adrs/adr-005-model-installation.md similarity index 100% rename from adr/adr-005-model-installation.md rename to specs/adrs/adr-005-model-installation.md diff --git a/adr/adr-006-jan-core-module.md b/specs/adrs/adr-006-jan-core-module.md similarity index 100% rename from adr/adr-006-jan-core-module.md rename to specs/adrs/adr-006-jan-core-module.md diff --git a/adr/adr-007-jan-plugin-catalog.md b/specs/adrs/adr-007-jan-plugin-catalog.md similarity index 100% rename from adr/adr-007-jan-plugin-catalog.md rename to specs/adrs/adr-007-jan-plugin-catalog.md diff --git a/adr/adr-008-Extensible-Jan-with-Docker.md b/specs/adrs/adr-008-Extensible-Jan-with-Docker.md similarity index 100% rename from adr/adr-008-Extensible-Jan-with-Docker.md rename to specs/adrs/adr-008-Extensible-Jan-with-Docker.md diff --git a/adr/images/adr-001-01.png b/specs/images/adr-001-01.png similarity index 100% rename from adr/images/adr-001-01.png rename to specs/images/adr-001-01.png diff --git a/adr/images/adr-001-02.png b/specs/images/adr-001-02.png similarity 
index 100% rename from adr/images/adr-001-02.png rename to specs/images/adr-001-02.png diff --git a/adr/images/adr-002-01.png b/specs/images/adr-002-01.png similarity index 100% rename from adr/images/adr-002-01.png rename to specs/images/adr-002-01.png diff --git a/adr/images/adr-002-02.png b/specs/images/adr-002-02.png similarity index 100% rename from adr/images/adr-002-02.png rename to specs/images/adr-002-02.png diff --git a/adr/images/adr-003-01.png b/specs/images/adr-003-01.png similarity index 100% rename from adr/images/adr-003-01.png rename to specs/images/adr-003-01.png diff --git a/adr/images/jan-core-module.png b/specs/images/jan-core-module.png similarity index 100% rename from adr/images/jan-core-module.png rename to specs/images/jan-core-module.png diff --git a/adr/images/jan-plugin-catalog.png b/specs/images/jan-plugin-catalog.png similarity index 100% rename from adr/images/jan-plugin-catalog.png rename to specs/images/jan-plugin-catalog.png diff --git a/adr/images/jan-ui-framework.png b/specs/images/jan-ui-framework.png similarity index 100% rename from adr/images/jan-ui-framework.png rename to specs/images/jan-ui-framework.png diff --git a/specs/jan-001-log-framework.md b/specs/jan-001-log-framework.md new file mode 100644 index 000000000..5af88d423 --- /dev/null +++ b/specs/jan-001-log-framework.md @@ -0,0 +1,101 @@ +# jan-001: Application Logs Framework + +| Proposal | jan-001 | +| ---------- | ----------------------------------------------------- | +| Title | App Logging | +| Authors | @louis-jan | +| Permalink | | +| Discussion | [issue #528](https://github.com/janhq/jan/issues/528) | +| Status | Idea | + +## Changelog + +| Date | Author | Changes | +| ------------ | ---------- | ------------- | +| Nov 2nd 2023 | @louis-jan | Initial Draft | + +## Summary + +This proposal suggests the implementation of an "App logging as file and log window" feature, which aims to address the problem of limited visibility into the operation of a production 
application. Currently, logs (info, verbose, error) are hidden, making it challenging for both users and developers to debug and support the application. The proposed solution involves logging application-wide activities to a file while also enabling real-time log monitoring through a dedicated log window within the application. + +## Motivation + +### Problem Description +The lack of proper logging in production applications results in several challenges: + +1. Debugging Difficulty: When an issue arises in a production environment, developers have limited access to essential information about what happened, making it challenging to diagnose and resolve problems effectively. +2. Support Challenges: Users often encounter errors or unexpected behavior, and support teams struggle to gather the necessary logs to understand the issue and provide a solution promptly. +3. Lack of Real-time Insights: Real-time monitoring is essential for identifying and responding to critical events. The absence of a log window within the application prevents timely reactions to events. + +### Use Case Example +Consider an e-commerce application. In the current state, when a user faces an issue during checkout, there's no easy way for support or development teams to see what went wrong in real time. This results in frustration for the user and a loss of business for the company + +```ts +# Current Status (Without the Feature) +try: + # Checkout logic + # ... +except Exception as e: + # Error handling + console.log(err) + # Insufficient logging +``` + +Without proper logging, it is challenging to diagnose the issue and provide immediate assistance. + +## Proposed solution + +### High-level overview +The proposed solution introduces the following key changes: + +1. Application-wide Logging: Implement a logging mechanism that logs application-wide activities to a designated log file. This ensures that all relevant information is captured for later analysis and debugging. +2. 
Real-time Log Window: Create a log window within the application that displays log entries in real time. Users and developers can open this window to monitor logs, allowing them to react promptly to events and errors. + +```ts +# With the Proposed Feature +try: + # Checkout logic + # ... +except Exception as e: + # Error handling + log.error(f"Error when downloading model: {e}") + # Proper logging + +``` + +![Image](https://github.com/janhq/jan/assets/133622055/b60f6976-8138-438e-aa4f-7e103037e124) + + +### Specification + +- The logging system will support different log levels (e.g., info, verbose, error) to ensure that the right level of detail is captured. +- Log entries will be timestamped and categorized to aid in the analysis and debugging process. +- The log window will provide options for filtering and searching log entries for ease of use. + + +### Compatibility + +This proposal aims to preserve backward compatibility by ensuring that the new logging system does not break existing functionality or affect existing applications negatively. It should not alter the semantics of valid programs. + + +### Other concerns + +- Implementation: Careful consideration should be given to the choice of logging framework and implementation details. +- Security: Access to logs and log window functionality should be restricted to authorized users to prevent potential security risks. + +### Open questions + +- What will be the default log file location, and how will it be configurable? +- Should log entries be persisted and rotated over time to prevent excessive file size? + +## Alternatives + +Alternative approaches may involve integrating with existing third-party logging systems or cloud-based log management platforms. However, this proposal focuses on a built-in solution for application-wide logging and real-time monitoring. + +## Related work + +This proposal is inspired by similar features in various application development frameworks and tools. 
+ +## FAQ + +No frequently asked questions at this time. \ No newline at end of file diff --git a/specs/spec-template.md b/specs/spec-template.md new file mode 100644 index 000000000..32929affb --- /dev/null +++ b/specs/spec-template.md @@ -0,0 +1,33 @@ +# jan-{SPEC-NUM}: {TITLE} + +| Proposal | jan-{SPEC-NUM} | +| ---------- | -------------- | +| Title | {TITLE} | +| Authors | | +| Permalink | | +| Discussion | | +| Status | Idea | + +## Changelog + +| Date | Author | Changes | +| ---- | ------ | ------------- | +| | | Initial Draft | + +## Abstract + +Summary. Please keep it very short. + +## Motivation + +Why? + +## Specification + +What, how? +- UX Mockups +- Code Interfaces + +## Appendix + +Everything else goes here. \ No newline at end of file diff --git a/web/app/_components/ConfirmationModal/index.tsx b/web/app/_components/ConfirmationModal/index.tsx new file mode 100644 index 000000000..baac7840d --- /dev/null +++ b/web/app/_components/ConfirmationModal/index.tsx @@ -0,0 +1,87 @@ +import React, { Fragment } from 'react' +import { Dialog, Transition } from '@headlessui/react' +import { QuestionMarkCircleIcon } from '@heroicons/react/24/outline' +import { PrimitiveAtom, useAtom } from 'jotai' + +interface Props { + atom: PrimitiveAtom + title: string + description: string + onConfirm: () => void +} + +const ConfirmationModal: React.FC = ({ atom, title, description, onConfirm }) => { + const [show, setShow] = useAtom(atom) + + return ( + + + +
+ + +
+
+ + +
+
+
+
+ + {title} + +
+

{description}

+
+
+
+
+ + +
+
+
+
+
+
+
+ ) +} + +export default React.memo(ConfirmationModal) diff --git a/web/app/_components/ExploreModelItemHeader/index.tsx b/web/app/_components/ExploreModelItemHeader/index.tsx index 55786dbc4..81228d0da 100644 --- a/web/app/_components/ExploreModelItemHeader/index.tsx +++ b/web/app/_components/ExploreModelItemHeader/index.tsx @@ -11,6 +11,8 @@ import { MainViewState, setMainViewStateAtom, } from '@helpers/atoms/MainView.atom' +import ConfirmationModal from '../ConfirmationModal' +import { showingCancelDownloadModalAtom } from '@helpers/atoms/Modal.atom' type Props = { suitableModel: ModelVersion @@ -31,6 +33,9 @@ const ExploreModelItemHeader: React.FC = ({ ) const downloadState = useAtomValue(downloadAtom) const setMainViewState = useSetAtom(setMainViewStateAtom) + const setShowingCancelDownloadModal = useSetAtom( + showingCancelDownloadModalAtom + ) useEffect(() => { getPerformanceForModel(suitableModel) @@ -70,17 +75,30 @@ const ExploreModelItemHeader: React.FC = ({ // downloading downloadButton = ( ) } + let cancelDownloadModal = + downloadState != null ? ( + { + window.coreAPI?.abortDownload(downloadState?.fileName) + }} + /> + ) : ( + <> + ) + return (
@@ -90,6 +108,7 @@ const ExploreModelItemHeader: React.FC = ({ )}
{downloadButton} + {cancelDownloadModal}
) } diff --git a/web/app/_components/HistoryList/index.tsx b/web/app/_components/HistoryList/index.tsx index 3cbb482f6..880521597 100644 --- a/web/app/_components/HistoryList/index.tsx +++ b/web/app/_components/HistoryList/index.tsx @@ -15,32 +15,33 @@ const HistoryList: React.FC = () => { useEffect(() => { getUserConversations() + // eslint-disable-next-line react-hooks/exhaustive-deps }, []) return ( -
+
-
    - {conversations.length > 0 ? ( - conversations + {conversations.length > 0 ? ( +
      + {conversations .filter( (e) => searchText.trim() === '' || e.name?.toLowerCase().includes(searchText.toLowerCase().trim()) ) - .map((convo) => ( + .map((convo, i) => ( - )) - ) : ( - - )} -
    + ))} +
+ ) : ( + + )}
) } diff --git a/web/app/_components/InputToolbar/index.tsx b/web/app/_components/InputToolbar/index.tsx index 4895d90cf..722073455 100644 --- a/web/app/_components/InputToolbar/index.tsx +++ b/web/app/_components/InputToolbar/index.tsx @@ -43,15 +43,7 @@ const InputToolbar: React.FC = () => { } if (!activeConvoId) { - return ( -
- } - /> -
- ) + return null } if ( (activeConvoId && inputState === 'model-mismatch') || @@ -86,7 +78,7 @@ const InputToolbar: React.FC = () => { if (conversations.length > 0) return ( -
+
{currentConvoState?.error && (
@@ -94,13 +86,13 @@ const InputToolbar: React.FC = () => {
)} -
+ {/*
} /> -
+
*/} {/* My text input */}
diff --git a/web/app/_components/LeftHeaderAction/index.tsx b/web/app/_components/LeftHeaderAction/index.tsx index 9c99e3260..f24f41122 100644 --- a/web/app/_components/LeftHeaderAction/index.tsx +++ b/web/app/_components/LeftHeaderAction/index.tsx @@ -1,23 +1,42 @@ 'use client' -import React from 'react' +import React, { useContext } from 'react' import SecondaryButton from '../SecondaryButton' -import { useSetAtom } from 'jotai' +import { useSetAtom, useAtomValue } from 'jotai' import { MainViewState, setMainViewStateAtom, } from '@helpers/atoms/MainView.atom' import { MagnifyingGlassIcon, PlusIcon } from '@heroicons/react/24/outline' +import useCreateConversation from '@hooks/useCreateConversation' import { useGetDownloadedModels } from '@hooks/useGetDownloadedModels' +import { Button } from '@uikit' +import { activeAssistantModelAtom } from '@helpers/atoms/Model.atom' +import { showingModalNoActiveModel } from '@helpers/atoms/Modal.atom' +import { + FeatureToggleContext, +} from '@helpers/FeatureToggleWrapper' const LeftHeaderAction: React.FC = () => { const setMainView = useSetAtom(setMainViewStateAtom) const { downloadedModels } = useGetDownloadedModels() + const activeModel = useAtomValue(activeAssistantModelAtom) + const { requestCreateConvo } = useCreateConversation() + const setShowModalNoActiveModel = useSetAtom(showingModalNoActiveModel) + const { experimentalFeatureEnabed } = useContext(FeatureToggleContext) const onExploreClick = () => { setMainView(MainViewState.ExploreModel) } + const onNewConversationClick = () => { + if (activeModel) { + requestCreateConvo(activeModel) + } else { + setShowModalNoActiveModel(true) + } + } + const onCreateBotClicked = () => { if (downloadedModels.length === 0) { alert('You need to download at least one model to create a bot.') @@ -27,19 +46,30 @@ const LeftHeaderAction: React.FC = () => { } return ( -
- } - /> - {/* } - /> */} +
+
+ } + /> + {experimentalFeatureEnabed && ( + } + /> + )} +
+
) } diff --git a/web/app/_components/Preferences.tsx b/web/app/_components/Preferences.tsx index 7870cbf2d..5cb10fc77 100644 --- a/web/app/_components/Preferences.tsx +++ b/web/app/_components/Preferences.tsx @@ -1,9 +1,6 @@ 'use client' import { useEffect, useRef, useState } from 'react' -import { - plugins, - extensionPoints, -} from '@/../../electron/core/plugin-manager/execution/index' +import { plugins, extensionPoints } from '@plugin' import { ChartPieIcon, CommandLineIcon, @@ -13,7 +10,7 @@ import { import { MagnifyingGlassIcon } from '@heroicons/react/20/solid' import classNames from 'classnames' import { DataService, PluginService, preferences } from '@janhq/core' -import { execute } from '../../../electron/core/plugin-manager/execution/extension-manager' +import { execute } from '@plugin/extension-manager' import LoadingIndicator from './LoadingIndicator' import { executeSerial } from '@services/pluginService' @@ -33,7 +30,7 @@ export const Preferences = () => { * Loads the plugin catalog module from a CDN and sets it as the plugin catalog state. */ useEffect(() => { - executeSerial(DataService.GetPluginManifest).then((data) => { + executeSerial(DataService.GetPluginManifest).then((data: any) => { setPluginCatalog(data) }) }, []) @@ -52,7 +49,7 @@ export const Preferences = () => { if (extensionPoints.get('experimentComponent')) { const components = await Promise.all( - extensionPoints.execute('experimentComponent') + extensionPoints.execute('experimentComponent', {}) ) if (components.length > 0) { setIsTestAvailable(true) @@ -67,7 +64,7 @@ export const Preferences = () => { if (extensionPoints.get('PluginPreferences')) { const data = await Promise.all( - extensionPoints.execute('PluginPreferences') + extensionPoints.execute('PluginPreferences', {}) ) setPreferenceItems(Array.isArray(data) ? 
data : []) Promise.all( @@ -149,7 +146,7 @@ export const Preferences = () => { } if (extensionPoints.get(PluginService.OnPreferencesUpdate)) timeout = setTimeout( - () => execute(PluginService.OnPreferencesUpdate), + () => execute(PluginService.OnPreferencesUpdate, {}), 100 ) } diff --git a/web/app/_components/ProgressBar/index.tsx b/web/app/_components/ProgressBar/index.tsx index 8981f671a..66f1e1f72 100644 --- a/web/app/_components/ProgressBar/index.tsx +++ b/web/app/_components/ProgressBar/index.tsx @@ -6,7 +6,7 @@ type Props = { } const ProgressBar: React.FC = ({ used, total }) => ( -
+
Updating diff --git a/web/app/_components/SecondaryButton/index.tsx b/web/app/_components/SecondaryButton/index.tsx index 2491edeba..be1d85146 100644 --- a/web/app/_components/SecondaryButton/index.tsx +++ b/web/app/_components/SecondaryButton/index.tsx @@ -16,7 +16,13 @@ const SecondaryButton: React.FC = ({ className, icon, }) => ( - diff --git a/web/app/_components/SidebarFooter/index.tsx b/web/app/_components/SidebarFooter/index.tsx index 42d7c3f3b..1b2430a5f 100644 --- a/web/app/_components/SidebarFooter/index.tsx +++ b/web/app/_components/SidebarFooter/index.tsx @@ -6,14 +6,14 @@ const SidebarFooter: React.FC = () => ( - window.electronAPI?.openExternalUrl('https://discord.gg/AsJ8krTT3N') + window.coreAPI?.openExternalUrl('https://discord.gg/AsJ8krTT3N') } className="flex-1" /> - window.electronAPI?.openExternalUrl('https://twitter.com/janhq_') + window.coreAPI?.openExternalUrl('https://twitter.com/janhq_') } className="flex-1" /> diff --git a/web/app/_components/SimpleTextMessage/index.tsx b/web/app/_components/SimpleTextMessage/index.tsx index 6085ae0fb..d38a507c4 100644 --- a/web/app/_components/SimpleTextMessage/index.tsx +++ b/web/app/_components/SimpleTextMessage/index.tsx @@ -51,7 +51,7 @@ const SimpleTextMessage: React.FC = ({ return (
{ const viewState = useAtomValue(getMainViewStateAtom) @@ -53,7 +56,10 @@ const Page: React.FC = () => { break } - return {children} + return + {children} +
+
} export default Page diff --git a/web/containers/BottomBar/index.tsx b/web/containers/BottomBar/index.tsx index 3b927c44e..fedddcfdb 100644 --- a/web/containers/BottomBar/index.tsx +++ b/web/containers/BottomBar/index.tsx @@ -6,6 +6,8 @@ import { modelDownloadStateAtom } from '@helpers/atoms/DownloadState.atom' import { formatDownloadPercentage } from '@utils/converter' import { activeAssistantModelAtom, stateModel } from '@helpers/atoms/Model.atom' import useGetAppVersion from '@hooks/useGetAppVersion' +import ProgressBar from '@/_components/ProgressBar' +import { appDownloadProgress } from '@helpers/JotaiWrapper' const BottomBar = () => { const activeModel = useAtomValue(activeAssistantModelAtom) @@ -13,6 +15,7 @@ const BottomBar = () => { const { ram, cpu } = useGetSystemResources() const modelDownloadStates = useAtomValue(modelDownloadStateAtom) const appVersion = useGetAppVersion() + const progress = useAtomValue(appDownloadProgress) const downloadStates: DownloadState[] = [] for (const [, value] of Object.entries(modelDownloadStates)) { @@ -21,7 +24,22 @@ const BottomBar = () => { return (
-
+
+
+ {progress && progress >= 0 ? ( + + ) : null} + {downloadStates.length > 0 && ( + + )} +
+ {stateModelStartStop.state === 'start' && stateModelStartStop.loading && ( { {!stateModelStartStop.loading && ( )} - {downloadStates.length > 0 && ( - - )}
-

Jan v{appVersion?.version ?? ''}

+

+ Jan v{appVersion?.version ?? ''} +

) diff --git a/web/containers/Providers/index.tsx b/web/containers/Providers/index.tsx index 623ef9a5e..897e5370b 100644 --- a/web/containers/Providers/index.tsx +++ b/web/containers/Providers/index.tsx @@ -7,15 +7,15 @@ import JotaiWrapper from '@helpers/JotaiWrapper' import { ModalWrapper } from '@helpers/ModalWrapper' import { useEffect, useState } from 'react' import CompactLogo from '@containers/Logo/CompactLogo' -import { - setup, - plugins, - activationPoints, - extensionPoints, -} from '../../../electron/core/plugin-manager/execution/index' +import { setup, plugins, activationPoints, extensionPoints } from '@plugin' import EventListenerWrapper from '@helpers/EventListenerWrapper' import { setupCoreServices } from '@services/coreService' -import { executeSerial, isCorePluginInstalled, setupBasePlugins } from '@services/pluginService' +import { + executeSerial, + isCorePluginInstalled, + setupBasePlugins, +} from '@services/pluginService' +import { FeatureToggleWrapper } from '@helpers/FeatureToggleWrapper' const Providers = (props: PropsWithChildren) => { const [setupCore, setSetupCore] = useState(false) @@ -57,7 +57,7 @@ const Providers = (props: PropsWithChildren) => { useEffect(() => { if (setupCore) { // Electron - if (window && window.electronAPI) { + if (window && window.coreAPI) { setupPE() } else { // Host @@ -71,9 +71,11 @@ const Providers = (props: PropsWithChildren) => { {setupCore && ( {activated ? ( - - {children} - + + + {children} + + ) : (
diff --git a/web/containers/Sidebar/Left.tsx b/web/containers/Sidebar/Left.tsx index 5510538db..62b298543 100644 --- a/web/containers/Sidebar/Left.tsx +++ b/web/containers/Sidebar/Left.tsx @@ -1,4 +1,4 @@ -import React from 'react' +import React, { useContext } from 'react' import { useAtomValue, useSetAtom } from 'jotai' import { MainViewState, @@ -20,6 +20,9 @@ import { twMerge } from 'tailwind-merge' import { showingBotListModalAtom } from '@helpers/atoms/Modal.atom' import useGetBots from '@hooks/useGetBots' import { useUserConfigs } from '@hooks/useUserConfigs' +import { + FeatureToggleContext, +} from '@helpers/FeatureToggleWrapper' export const SidebarLeft = () => { const [config] = useUserConfigs() @@ -28,6 +31,7 @@ export const SidebarLeft = () => { const setBotListModal = useSetAtom(showingBotListModalAtom) const { downloadedModels } = useGetDownloadedModels() const { getAllBots } = useGetBots() + const { experimentalFeatureEnabed } = useContext(FeatureToggleContext) const onMenuClick = (mainViewState: MainViewState) => { if (currentState === mainViewState) return @@ -88,18 +92,21 @@ export const SidebarLeft = () => { icon: , state: MainViewState.MyModel, }, - // { - // name: 'Bot', - // icon: , - // state: MainViewState.CreateBot, - // }, + ...(experimentalFeatureEnabed + ? [ + { + name: 'Bot', + icon: , + state: MainViewState.CreateBot, + }, + ] + : []), { name: 'Settings', icon: , state: MainViewState.Setting, }, ] - return ( { config.sidebarLeftExpand ? 'items-start' : 'items-center' )} > - {menus.map((menu, i) => { - const isActive = currentState === menu.state - const isBotMenu = menu.name === 'Bot' - return ( -
- - {isActive ? ( - - ) : null} -
- ) - })} + {menu.icon} + + {menu.name} + + + {isActive ? ( + + ) : null} +
+ ) + })}
{