Merge pull request #409 from janhq/393-nitro-new-version-integration-inference-params-update-config-in-runtime
feat: Support for nitro release 0.1.2
This commit is contained in:
commit
28f8b8e032
1
plugins/inference-plugin/@types/global.d.ts
vendored
1
plugins/inference-plugin/@types/global.d.ts
vendored
@ -1,2 +1,3 @@
|
|||||||
declare const PLUGIN_NAME: string;
|
declare const PLUGIN_NAME: string;
|
||||||
declare const MODULE_PATH: string;
|
declare const MODULE_PATH: string;
|
||||||
|
declare const INFERENCE_URL: string;
|
||||||
|
|||||||
@ -1,17 +1,23 @@
|
|||||||
import { EventName, InferenceService, NewMessageRequest, PluginService, core, events, store } from "@janhq/core";
|
import {
|
||||||
|
EventName,
|
||||||
|
InferenceService,
|
||||||
|
NewMessageRequest,
|
||||||
|
PluginService,
|
||||||
|
events,
|
||||||
|
store,
|
||||||
|
invokePluginFunc,
|
||||||
|
} from "@janhq/core";
|
||||||
import { Observable } from "rxjs";
|
import { Observable } from "rxjs";
|
||||||
|
|
||||||
const inferenceUrl = "http://localhost:3928/llama/chat_completion";
|
const initModel = async (product) => invokePluginFunc(MODULE_PATH, "initModel", product);
|
||||||
|
|
||||||
const initModel = async (product) => core.invokePluginFunc(MODULE_PATH, "initModel", product);
|
|
||||||
|
|
||||||
const stopModel = () => {
|
const stopModel = () => {
|
||||||
core.invokePluginFunc(MODULE_PATH, "killSubprocess");
|
invokePluginFunc(MODULE_PATH, "killSubprocess");
|
||||||
};
|
};
|
||||||
|
|
||||||
function requestInference(recentMessages: any[]): Observable<string> {
|
function requestInference(recentMessages: any[]): Observable<string> {
|
||||||
return new Observable((subscriber) => {
|
return new Observable((subscriber) => {
|
||||||
fetch(inferenceUrl, {
|
fetch(INFERENCE_URL, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: {
|
headers: {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
|
|||||||
@ -5,8 +5,8 @@ const fs = require("fs");
|
|||||||
const tcpPortUsed = require("tcp-port-used");
|
const tcpPortUsed = require("tcp-port-used");
|
||||||
const { killPortProcess } = require("kill-port-process");
|
const { killPortProcess } = require("kill-port-process");
|
||||||
|
|
||||||
let subprocess = null;
|
|
||||||
const PORT = 3928;
|
const PORT = 3928;
|
||||||
|
let subprocess = null;
|
||||||
|
|
||||||
const initModel = (fileName) => {
|
const initModel = (fileName) => {
|
||||||
return (
|
return (
|
||||||
@ -15,9 +15,7 @@ const initModel = (fileName) => {
|
|||||||
reject("Model not found, please download again.");
|
reject("Model not found, please download again.");
|
||||||
}
|
}
|
||||||
if (subprocess) {
|
if (subprocess) {
|
||||||
console.error(
|
console.error("A subprocess is already running. Attempt to kill then reinit.");
|
||||||
"A subprocess is already running. Attempt to kill then reinit."
|
|
||||||
);
|
|
||||||
killSubprocess();
|
killSubprocess();
|
||||||
}
|
}
|
||||||
resolve(fileName);
|
resolve(fileName);
|
||||||
@ -32,35 +30,13 @@ const initModel = (fileName) => {
|
|||||||
// Spawn Nitro subprocess to load model
|
// Spawn Nitro subprocess to load model
|
||||||
.then(() => {
|
.then(() => {
|
||||||
let binaryFolder = path.join(__dirname, "nitro"); // Current directory by default
|
let binaryFolder = path.join(__dirname, "nitro"); // Current directory by default
|
||||||
|
|
||||||
// Read the existing config
|
|
||||||
const configFilePath = path.join(binaryFolder, "config", "config.json");
|
|
||||||
let config: any = {};
|
|
||||||
if (fs.existsSync(configFilePath)) {
|
|
||||||
const rawData = fs.readFileSync(configFilePath, "utf-8");
|
|
||||||
config = JSON.parse(rawData);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update the llama_model_path
|
|
||||||
if (!config.custom_config) {
|
|
||||||
config.custom_config = {};
|
|
||||||
}
|
|
||||||
|
|
||||||
const modelPath = path.join(app.getPath("userData"), fileName);
|
|
||||||
|
|
||||||
config.custom_config.llama_model_path = modelPath;
|
|
||||||
|
|
||||||
// Write the updated config back to the file
|
|
||||||
fs.writeFileSync(configFilePath, JSON.stringify(config, null, 4));
|
|
||||||
|
|
||||||
let binaryName;
|
let binaryName;
|
||||||
|
|
||||||
if (process.platform === "win32") {
|
if (process.platform === "win32") {
|
||||||
binaryName = "nitro_windows_amd64.exe";
|
binaryName = "nitro_windows_amd64_cuda.exe";
|
||||||
} else if (process.platform === "darwin") {
|
} else if (process.platform === "darwin") {
|
||||||
// Mac OS platform
|
// Mac OS platform
|
||||||
binaryName =
|
binaryName = process.arch === "arm64" ? "nitro_mac_arm64" : "nitro_mac_intel";
|
||||||
process.arch === "arm64" ? "nitro_mac_arm64" : "nitro_mac_amd64";
|
|
||||||
} else {
|
} else {
|
||||||
// Linux
|
// Linux
|
||||||
binaryName = "nitro_linux_amd64_cuda"; // For other platforms
|
binaryName = "nitro_linux_amd64_cuda"; // For other platforms
|
||||||
@ -69,8 +45,7 @@ const initModel = (fileName) => {
|
|||||||
const binaryPath = path.join(binaryFolder, binaryName);
|
const binaryPath = path.join(binaryFolder, binaryName);
|
||||||
|
|
||||||
// Execute the binary
|
// Execute the binary
|
||||||
|
subprocess = spawn(binaryPath, { cwd: binaryFolder });
|
||||||
subprocess = spawn(binaryPath, [configFilePath], { cwd: binaryFolder });
|
|
||||||
|
|
||||||
// Handle subprocess output
|
// Handle subprocess output
|
||||||
subprocess.stdout.on("data", (data) => {
|
subprocess.stdout.on("data", (data) => {
|
||||||
@ -88,7 +63,29 @@ const initModel = (fileName) => {
|
|||||||
})
|
})
|
||||||
.then(() => tcpPortUsed.waitUntilUsed(PORT, 300, 30000))
|
.then(() => tcpPortUsed.waitUntilUsed(PORT, 300, 30000))
|
||||||
.then(() => {
|
.then(() => {
|
||||||
|
const llama_model_path = path.join(app.getPath("userData"), fileName);
|
||||||
|
|
||||||
|
const config = {
|
||||||
|
llama_model_path,
|
||||||
|
ctx_len: 2048,
|
||||||
|
ngl: 100,
|
||||||
|
embedding: true, // Always enable embedding mode on
|
||||||
|
};
|
||||||
|
|
||||||
|
// Load model config
|
||||||
|
return fetch(`http://127.0.0.1:${PORT}/inferences/llamacpp/loadmodel`, {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify(config),
|
||||||
|
});
|
||||||
|
})
|
||||||
|
.then((res) => {
|
||||||
|
if (res.ok) {
|
||||||
return {};
|
return {};
|
||||||
|
}
|
||||||
|
throw new Error("Nitro: Model failed to load.");
|
||||||
})
|
})
|
||||||
.catch((err) => {
|
.catch((err) => {
|
||||||
return { error: err };
|
return { error: err };
|
||||||
|
|||||||
@ -1,13 +0,0 @@
|
|||||||
{
|
|
||||||
"listeners": [
|
|
||||||
{
|
|
||||||
"address": "0.0.0.0",
|
|
||||||
"port": 3928
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"custom_config": {
|
|
||||||
"llama_model_path": "",
|
|
||||||
"ctx_len": 2048,
|
|
||||||
"ngl": 100
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -132,6 +132,13 @@ kernel void kernel_relu(
|
|||||||
dst[tpig] = max(0.0f, src0[tpig]);
|
dst[tpig] = max(0.0f, src0[tpig]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_sqr(
|
||||||
|
device const float * src0,
|
||||||
|
device float * dst,
|
||||||
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
|
dst[tpig] = src0[tpig] * src0[tpig];
|
||||||
|
}
|
||||||
|
|
||||||
constant float GELU_COEF_A = 0.044715f;
|
constant float GELU_COEF_A = 0.044715f;
|
||||||
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||||
|
|
||||||
@ -340,6 +347,7 @@ kernel void kernel_rms_norm(
|
|||||||
uint ntg[[threads_per_threadgroup]]) {
|
uint ntg[[threads_per_threadgroup]]) {
|
||||||
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
||||||
device const float * x_scalar = (device const float *) x;
|
device const float * x_scalar = (device const float *) x;
|
||||||
|
|
||||||
float4 sumf = 0;
|
float4 sumf = 0;
|
||||||
float all_sum = 0;
|
float all_sum = 0;
|
||||||
|
|
||||||
@ -354,6 +362,7 @@ kernel void kernel_rms_norm(
|
|||||||
}
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// broadcast, simd group number is ntg / 32
|
// broadcast, simd group number is ntg / 32
|
||||||
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
||||||
if (tpitg < i) {
|
if (tpitg < i) {
|
||||||
@ -361,7 +370,9 @@ kernel void kernel_rms_norm(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (tpitg == 0) {
|
if (tpitg == 0) {
|
||||||
for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
|
for (int i = 4 * (ne00 / 4); i < ne00; i++) {
|
||||||
|
sum[0] += x_scalar[i];
|
||||||
|
}
|
||||||
sum[0] /= ne00;
|
sum[0] /= ne00;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -376,7 +387,9 @@ kernel void kernel_rms_norm(
|
|||||||
y[i00] = x[i00] * scale;
|
y[i00] = x[i00] * scale;
|
||||||
}
|
}
|
||||||
if (tpitg == 0) {
|
if (tpitg == 0) {
|
||||||
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
|
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
|
||||||
|
y_scalar[i00] = x_scalar[i00] * scale;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -428,18 +441,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|||||||
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
|
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
|
||||||
uint3 tgpig, uint tiisg, uint sgitg) {
|
uint3 tgpig, uint tiisg, uint sgitg) {
|
||||||
const int nb = ne00/QK4_0;
|
const int nb = ne00/QK4_0;
|
||||||
|
|
||||||
const int r0 = tgpig.x;
|
const int r0 = tgpig.x;
|
||||||
const int r1 = tgpig.y;
|
const int r1 = tgpig.y;
|
||||||
const int im = tgpig.z;
|
const int im = tgpig.z;
|
||||||
|
|
||||||
const int first_row = (r0 * nsg + sgitg) * nr;
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
||||||
|
|
||||||
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
|
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
|
||||||
|
|
||||||
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
||||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||||
|
|
||||||
float yl[16]; // src1 vector cache
|
float yl[16]; // src1 vector cache
|
||||||
float sumf[nr] = {0.f};
|
float sumf[nr] = {0.f};
|
||||||
|
|
||||||
const int ix = tiisg/2;
|
const int ix = (tiisg/2);
|
||||||
const int il = 8*(tiisg%2);
|
const int il = (tiisg%2)*8;
|
||||||
|
|
||||||
device const float * yb = y + ix * QK4_0 + il;
|
device const float * yb = y + ix * QK4_0 + il;
|
||||||
|
|
||||||
@ -450,6 +468,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|||||||
sumy += yb[i] + yb[i+1];
|
sumy += yb[i] + yb[i+1];
|
||||||
yl[i+0] = yb[i+ 0];
|
yl[i+0] = yb[i+ 0];
|
||||||
yl[i+1] = yb[i+ 1]/256.f;
|
yl[i+1] = yb[i+ 1]/256.f;
|
||||||
|
|
||||||
sumy += yb[i+16] + yb[i+17];
|
sumy += yb[i+16] + yb[i+17];
|
||||||
yl[i+8] = yb[i+16]/16.f;
|
yl[i+8] = yb[i+16]/16.f;
|
||||||
yl[i+9] = yb[i+17]/4096.f;
|
yl[i+9] = yb[i+17]/4096.f;
|
||||||
@ -465,12 +484,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|||||||
for (int row = 0; row < nr; ++row) {
|
for (int row = 0; row < nr; ++row) {
|
||||||
const float tot = simd_sum(sumf[row]);
|
const float tot = simd_sum(sumf[row]);
|
||||||
if (tiisg == 0 && first_row + row < ne01) {
|
if (tiisg == 0 && first_row + row < ne01) {
|
||||||
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mat_q4_0_f32(
|
kernel void kernel_mul_mv_q4_0_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -488,7 +507,7 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|||||||
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mat_q4_1_f32(
|
kernel void kernel_mul_mv_q4_1_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -508,7 +527,7 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|||||||
|
|
||||||
#define NB_Q8_0 8
|
#define NB_Q8_0 8
|
||||||
|
|
||||||
kernel void kernel_mul_mat_q8_0_f32(
|
kernel void kernel_mul_mv_q8_0_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -572,7 +591,7 @@ kernel void kernel_mul_mat_q8_0_f32(
|
|||||||
|
|
||||||
#define N_F32_F32 4
|
#define N_F32_F32 4
|
||||||
|
|
||||||
kernel void kernel_mul_mat_f32_f32(
|
kernel void kernel_mul_mv_f32_f32(
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -643,7 +662,7 @@ kernel void kernel_mul_mat_f32_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mat_f16_f32_1row(
|
kernel void kernel_mul_mv_f16_f32_1row(
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -697,7 +716,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
|
|||||||
|
|
||||||
#define N_F16_F32 4
|
#define N_F16_F32 4
|
||||||
|
|
||||||
kernel void kernel_mul_mat_f16_f32(
|
kernel void kernel_mul_mv_f16_f32(
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -769,7 +788,7 @@ kernel void kernel_mul_mat_f16_f32(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Assumes row size (ne00) is a multiple of 4
|
// Assumes row size (ne00) is a multiple of 4
|
||||||
kernel void kernel_mul_mat_f16_f32_l4(
|
kernel void kernel_mul_mv_f16_f32_l4(
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -1098,6 +1117,62 @@ kernel void kernel_cpy_f32_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_concat(
|
||||||
|
device const char * src0,
|
||||||
|
device const char * src1,
|
||||||
|
device char * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne01,
|
||||||
|
constant int64_t & ne02,
|
||||||
|
constant int64_t & ne03,
|
||||||
|
constant uint64_t & nb00,
|
||||||
|
constant uint64_t & nb01,
|
||||||
|
constant uint64_t & nb02,
|
||||||
|
constant uint64_t & nb03,
|
||||||
|
constant int64_t & ne10,
|
||||||
|
constant int64_t & ne11,
|
||||||
|
constant int64_t & ne12,
|
||||||
|
constant int64_t & ne13,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb12,
|
||||||
|
constant uint64_t & nb13,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
constant int64_t & ne2,
|
||||||
|
constant int64_t & ne3,
|
||||||
|
constant uint64_t & nb0,
|
||||||
|
constant uint64_t & nb1,
|
||||||
|
constant uint64_t & nb2,
|
||||||
|
constant uint64_t & nb3,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
|
||||||
|
const int64_t i03 = tgpig.z;
|
||||||
|
const int64_t i02 = tgpig.y;
|
||||||
|
const int64_t i01 = tgpig.x;
|
||||||
|
|
||||||
|
const int64_t i13 = i03 % ne13;
|
||||||
|
const int64_t i12 = i02 % ne12;
|
||||||
|
const int64_t i11 = i01 % ne11;
|
||||||
|
|
||||||
|
device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
|
||||||
|
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
|
||||||
|
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
|
||||||
|
|
||||||
|
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
||||||
|
if (i02 < ne02) {
|
||||||
|
((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
|
||||||
|
src0_ptr += ntg.x*nb00;
|
||||||
|
} else {
|
||||||
|
((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
|
||||||
|
src1_ptr += ntg.x*nb10;
|
||||||
|
}
|
||||||
|
dst_ptr += ntg.x*nb0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//============================================ k-quants ======================================================
|
//============================================ k-quants ======================================================
|
||||||
|
|
||||||
#ifndef QK_K
|
#ifndef QK_K
|
||||||
@ -1190,7 +1265,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|||||||
|
|
||||||
//====================================== dot products =========================
|
//====================================== dot products =========================
|
||||||
|
|
||||||
kernel void kernel_mul_mat_q2_K_f32(
|
kernel void kernel_mul_mv_q2_K_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -1334,7 +1409,7 @@ kernel void kernel_mul_mat_q2_K_f32(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
kernel void kernel_mul_mat_q3_K_f32(
|
kernel void kernel_mul_mv_q3_K_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -1486,7 +1561,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
kernel void kernel_mul_mat_q3_K_f32(
|
kernel void kernel_mul_mv_q3_K_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -1557,7 +1632,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
kernel void kernel_mul_mat_q4_K_f32(
|
kernel void kernel_mul_mv_q4_K_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -1663,7 +1738,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
kernel void kernel_mul_mat_q4_K_f32(
|
kernel void kernel_mul_mv_q4_K_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -1752,7 +1827,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
kernel void kernel_mul_mat_q5_K_f32(
|
kernel void kernel_mul_mv_q5_K_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -1925,7 +2000,7 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mat_q6_K_f32(
|
kernel void kernel_mul_mv_q6_K_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -2263,7 +2338,7 @@ kernel void kernel_get_rows(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
||||||
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
|
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
||||||
#define BLOCK_SIZE_K 32
|
#define BLOCK_SIZE_K 32
|
||||||
#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
|
#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
|
||||||
#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
|
#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
|
||||||
@ -2300,9 +2375,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|||||||
const uint r0 = tgpig.y;
|
const uint r0 = tgpig.y;
|
||||||
const uint r1 = tgpig.x;
|
const uint r1 = tgpig.x;
|
||||||
const uint im = tgpig.z;
|
const uint im = tgpig.z;
|
||||||
|
|
||||||
// if this block is of 64x32 shape or smaller
|
// if this block is of 64x32 shape or smaller
|
||||||
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
||||||
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
||||||
|
|
||||||
// a thread shouldn't load data outside of the matrix
|
// a thread shouldn't load data outside of the matrix
|
||||||
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
||||||
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
||||||
@ -2330,22 +2407,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|||||||
half4x4 temp_a;
|
half4x4 temp_a;
|
||||||
dequantize_func(x, il, temp_a);
|
dequantize_func(x, il, temp_a);
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
#pragma unroll(16)
|
#pragma unroll(16)
|
||||||
for (int i = 0; i < 16; i++) {
|
for (int i = 0; i < 16; i++) {
|
||||||
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
||||||
+ 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
|
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
|
||||||
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
||||||
}
|
}
|
||||||
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
|
|
||||||
= *((device float2x4 *)y);
|
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
|
||||||
|
|
||||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||||
x = (il < 2) ? x + (2+nl-1)/nl : x;
|
x = (il < 2) ? x + (2+nl-1)/nl : x;
|
||||||
y += BLOCK_SIZE_K;
|
y += BLOCK_SIZE_K;
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// load matrices from threadgroup memory and conduct outer products
|
// load matrices from threadgroup memory and conduct outer products
|
||||||
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
||||||
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
||||||
|
|
||||||
#pragma unroll(4)
|
#pragma unroll(4)
|
||||||
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
||||||
#pragma unroll(4)
|
#pragma unroll(4)
|
||||||
@ -2360,6 +2441,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|||||||
|
|
||||||
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
||||||
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
||||||
|
|
||||||
#pragma unroll(8)
|
#pragma unroll(8)
|
||||||
for (int i = 0; i < 8; i++){
|
for (int i = 0; i < 8; i++){
|
||||||
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
|
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
|
||||||
@ -2368,7 +2450,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
|
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
|
||||||
device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
|
device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
|
||||||
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
|
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
|
||||||
for (int i = 0; i < 8; i++) {
|
for (int i = 0; i < 8; i++) {
|
||||||
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
|
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
|
||||||
@ -2383,7 +2465,8 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|||||||
}
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
|
|
||||||
|
device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
|
||||||
if (sgitg == 0) {
|
if (sgitg == 0) {
|
||||||
for (int i = 0; i < n_rows; i++) {
|
for (int i = 0; i < n_rows; i++) {
|
||||||
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
plugins/inference-plugin/nitro/nitro_windows_amd64_cuda.exe
Normal file
BIN
plugins/inference-plugin/nitro/nitro_windows_amd64_cuda.exe
Normal file
Binary file not shown.
@ -19,6 +19,7 @@ module.exports = {
|
|||||||
new webpack.DefinePlugin({
|
new webpack.DefinePlugin({
|
||||||
PLUGIN_NAME: JSON.stringify(packageJson.name),
|
PLUGIN_NAME: JSON.stringify(packageJson.name),
|
||||||
MODULE_PATH: JSON.stringify(`${packageJson.name}/${packageJson.module}`),
|
MODULE_PATH: JSON.stringify(`${packageJson.name}/${packageJson.module}`),
|
||||||
|
INFERENCE_URL: JSON.stringify("http://127.0.0.1:3928/inferences/llamacpp/chat_completion"),
|
||||||
}),
|
}),
|
||||||
],
|
],
|
||||||
output: {
|
output: {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user