Merge pull request #409 from janhq/393-nitro-new-version-integration-inference-params-update-config-in-runtime

feat: Support for nitro release 0.1.2
This commit is contained in:
namvuong 2023-10-19 20:59:16 -07:00 committed by GitHub
commit 28f8b8e032
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 171 additions and 96 deletions

View File

@ -1,2 +1,3 @@
declare const PLUGIN_NAME: string; declare const PLUGIN_NAME: string;
declare const MODULE_PATH: string; declare const MODULE_PATH: string;
declare const INFERENCE_URL: string;

View File

@ -1,17 +1,23 @@
import { EventName, InferenceService, NewMessageRequest, PluginService, core, events, store } from "@janhq/core"; import {
EventName,
InferenceService,
NewMessageRequest,
PluginService,
events,
store,
invokePluginFunc,
} from "@janhq/core";
import { Observable } from "rxjs"; import { Observable } from "rxjs";
const inferenceUrl = "http://localhost:3928/llama/chat_completion"; const initModel = async (product) => invokePluginFunc(MODULE_PATH, "initModel", product);
const initModel = async (product) => core.invokePluginFunc(MODULE_PATH, "initModel", product);
const stopModel = () => { const stopModel = () => {
core.invokePluginFunc(MODULE_PATH, "killSubprocess"); invokePluginFunc(MODULE_PATH, "killSubprocess");
}; };
function requestInference(recentMessages: any[]): Observable<string> { function requestInference(recentMessages: any[]): Observable<string> {
return new Observable((subscriber) => { return new Observable((subscriber) => {
fetch(inferenceUrl, { fetch(INFERENCE_URL, {
method: "POST", method: "POST",
headers: { headers: {
"Content-Type": "application/json", "Content-Type": "application/json",

View File

@ -5,8 +5,8 @@ const fs = require("fs");
const tcpPortUsed = require("tcp-port-used"); const tcpPortUsed = require("tcp-port-used");
const { killPortProcess } = require("kill-port-process"); const { killPortProcess } = require("kill-port-process");
let subprocess = null;
const PORT = 3928; const PORT = 3928;
let subprocess = null;
const initModel = (fileName) => { const initModel = (fileName) => {
return ( return (
@ -15,9 +15,7 @@ const initModel = (fileName) => {
reject("Model not found, please download again."); reject("Model not found, please download again.");
} }
if (subprocess) { if (subprocess) {
console.error( console.error("A subprocess is already running. Attempt to kill then reinit.");
"A subprocess is already running. Attempt to kill then reinit."
);
killSubprocess(); killSubprocess();
} }
resolve(fileName); resolve(fileName);
@ -32,35 +30,13 @@ const initModel = (fileName) => {
// Spawn Nitro subprocess to load model // Spawn Nitro subprocess to load model
.then(() => { .then(() => {
let binaryFolder = path.join(__dirname, "nitro"); // Current directory by default let binaryFolder = path.join(__dirname, "nitro"); // Current directory by default
// Read the existing config
const configFilePath = path.join(binaryFolder, "config", "config.json");
let config: any = {};
if (fs.existsSync(configFilePath)) {
const rawData = fs.readFileSync(configFilePath, "utf-8");
config = JSON.parse(rawData);
}
// Update the llama_model_path
if (!config.custom_config) {
config.custom_config = {};
}
const modelPath = path.join(app.getPath("userData"), fileName);
config.custom_config.llama_model_path = modelPath;
// Write the updated config back to the file
fs.writeFileSync(configFilePath, JSON.stringify(config, null, 4));
let binaryName; let binaryName;
if (process.platform === "win32") { if (process.platform === "win32") {
binaryName = "nitro_windows_amd64.exe"; binaryName = "nitro_windows_amd64_cuda.exe";
} else if (process.platform === "darwin") { } else if (process.platform === "darwin") {
// Mac OS platform // Mac OS platform
binaryName = binaryName = process.arch === "arm64" ? "nitro_mac_arm64" : "nitro_mac_intel";
process.arch === "arm64" ? "nitro_mac_arm64" : "nitro_mac_amd64";
} else { } else {
// Linux // Linux
binaryName = "nitro_linux_amd64_cuda"; // For other platforms binaryName = "nitro_linux_amd64_cuda"; // For other platforms
@ -69,8 +45,7 @@ const initModel = (fileName) => {
const binaryPath = path.join(binaryFolder, binaryName); const binaryPath = path.join(binaryFolder, binaryName);
// Execute the binary // Execute the binary
subprocess = spawn(binaryPath, { cwd: binaryFolder });
subprocess = spawn(binaryPath, [configFilePath], { cwd: binaryFolder });
// Handle subprocess output // Handle subprocess output
subprocess.stdout.on("data", (data) => { subprocess.stdout.on("data", (data) => {
@ -88,7 +63,29 @@ const initModel = (fileName) => {
}) })
.then(() => tcpPortUsed.waitUntilUsed(PORT, 300, 30000)) .then(() => tcpPortUsed.waitUntilUsed(PORT, 300, 30000))
.then(() => { .then(() => {
const llama_model_path = path.join(app.getPath("userData"), fileName);
const config = {
llama_model_path,
ctx_len: 2048,
ngl: 100,
embedding: true, // Always enable embedding mode on
};
// Load model config
return fetch(`http://127.0.0.1:${PORT}/inferences/llamacpp/loadmodel`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(config),
});
})
.then((res) => {
if (res.ok) {
return {}; return {};
}
throw new Error("Nitro: Model failed to load.");
}) })
.catch((err) => { .catch((err) => {
return { error: err }; return { error: err };

View File

@ -1,13 +0,0 @@
{
"listeners": [
{
"address": "0.0.0.0",
"port": 3928
}
],
"custom_config": {
"llama_model_path": "",
"ctx_len": 2048,
"ngl": 100
}
}

View File

@ -132,6 +132,13 @@ kernel void kernel_relu(
dst[tpig] = max(0.0f, src0[tpig]); dst[tpig] = max(0.0f, src0[tpig]);
} }
kernel void kernel_sqr(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src0[tpig];
}
constant float GELU_COEF_A = 0.044715f; constant float GELU_COEF_A = 0.044715f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@ -340,6 +347,7 @@ kernel void kernel_rms_norm(
uint ntg[[threads_per_threadgroup]]) { uint ntg[[threads_per_threadgroup]]) {
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
device const float * x_scalar = (device const float *) x; device const float * x_scalar = (device const float *) x;
float4 sumf = 0; float4 sumf = 0;
float all_sum = 0; float all_sum = 0;
@ -354,6 +362,7 @@ kernel void kernel_rms_norm(
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
// broadcast, simd group number is ntg / 32 // broadcast, simd group number is ntg / 32
for (uint i = ntg / 32 / 2; i > 0; i /= 2) { for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
if (tpitg < i) { if (tpitg < i) {
@ -361,7 +370,9 @@ kernel void kernel_rms_norm(
} }
} }
if (tpitg == 0) { if (tpitg == 0) {
for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];} for (int i = 4 * (ne00 / 4); i < ne00; i++) {
sum[0] += x_scalar[i];
}
sum[0] /= ne00; sum[0] /= ne00;
} }
@ -376,7 +387,9 @@ kernel void kernel_rms_norm(
y[i00] = x[i00] * scale; y[i00] = x[i00] * scale;
} }
if (tpitg == 0) { if (tpitg == 0) {
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;} for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
y_scalar[i00] = x_scalar[i00] * scale;
}
} }
} }
@ -428,18 +441,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
uint3 tgpig, uint tiisg, uint sgitg) { uint3 tgpig, uint tiisg, uint sgitg) {
const int nb = ne00/QK4_0; const int nb = ne00/QK4_0;
const int r0 = tgpig.x; const int r0 = tgpig.x;
const int r1 = tgpig.y; const int r1 = tgpig.y;
const int im = tgpig.z; const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr; const int first_row = (r0 * nsg + sgitg) * nr;
const uint offset0 = first_row * nb + im/gqa*(nb*ne0); const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
device const block_q_type * x = (device const block_q_type *) src0 + offset0; device const block_q_type * x = (device const block_q_type *) src0 + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
float yl[16]; // src1 vector cache float yl[16]; // src1 vector cache
float sumf[nr] = {0.f}; float sumf[nr] = {0.f};
const int ix = tiisg/2; const int ix = (tiisg/2);
const int il = 8*(tiisg%2); const int il = (tiisg%2)*8;
device const float * yb = y + ix * QK4_0 + il; device const float * yb = y + ix * QK4_0 + il;
@ -450,6 +468,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
sumy += yb[i] + yb[i+1]; sumy += yb[i] + yb[i+1];
yl[i+0] = yb[i+ 0]; yl[i+0] = yb[i+ 0];
yl[i+1] = yb[i+ 1]/256.f; yl[i+1] = yb[i+ 1]/256.f;
sumy += yb[i+16] + yb[i+17]; sumy += yb[i+16] + yb[i+17];
yl[i+8] = yb[i+16]/16.f; yl[i+8] = yb[i+16]/16.f;
yl[i+9] = yb[i+17]/4096.f; yl[i+9] = yb[i+17]/4096.f;
@ -465,12 +484,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
for (int row = 0; row < nr; ++row) { for (int row = 0; row < nr; ++row) {
const float tot = simd_sum(sumf[row]); const float tot = simd_sum(sumf[row]);
if (tiisg == 0 && first_row + row < ne01) { if (tiisg == 0 && first_row + row < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
} }
} }
} }
kernel void kernel_mul_mat_q4_0_f32( kernel void kernel_mul_mv_q4_0_f32(
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
@ -488,7 +507,7 @@ kernel void kernel_mul_mat_q4_0_f32(
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
} }
kernel void kernel_mul_mat_q4_1_f32( kernel void kernel_mul_mv_q4_1_f32(
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
@ -508,7 +527,7 @@ kernel void kernel_mul_mat_q4_1_f32(
#define NB_Q8_0 8 #define NB_Q8_0 8
kernel void kernel_mul_mat_q8_0_f32( kernel void kernel_mul_mv_q8_0_f32(
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
@ -572,7 +591,7 @@ kernel void kernel_mul_mat_q8_0_f32(
#define N_F32_F32 4 #define N_F32_F32 4
kernel void kernel_mul_mat_f32_f32( kernel void kernel_mul_mv_f32_f32(
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
@ -643,7 +662,7 @@ kernel void kernel_mul_mat_f32_f32(
} }
} }
kernel void kernel_mul_mat_f16_f32_1row( kernel void kernel_mul_mv_f16_f32_1row(
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
@ -697,7 +716,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
#define N_F16_F32 4 #define N_F16_F32 4
kernel void kernel_mul_mat_f16_f32( kernel void kernel_mul_mv_f16_f32(
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
@ -769,7 +788,7 @@ kernel void kernel_mul_mat_f16_f32(
} }
// Assumes row size (ne00) is a multiple of 4 // Assumes row size (ne00) is a multiple of 4
kernel void kernel_mul_mat_f16_f32_l4( kernel void kernel_mul_mv_f16_f32_l4(
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
@ -1098,6 +1117,62 @@ kernel void kernel_cpy_f32_f32(
} }
} }
kernel void kernel_concat(
device const char * src0,
device const char * src1,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig.z;
const int64_t i02 = tgpig.y;
const int64_t i01 = tgpig.x;
const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
if (i02 < ne02) {
((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
src0_ptr += ntg.x*nb00;
} else {
((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
src1_ptr += ntg.x*nb10;
}
dst_ptr += ntg.x*nb0;
}
}
//============================================ k-quants ====================================================== //============================================ k-quants ======================================================
#ifndef QK_K #ifndef QK_K
@ -1190,7 +1265,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
//====================================== dot products ========================= //====================================== dot products =========================
kernel void kernel_mul_mat_q2_K_f32( kernel void kernel_mul_mv_q2_K_f32(
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
@ -1334,7 +1409,7 @@ kernel void kernel_mul_mat_q2_K_f32(
} }
#if QK_K == 256 #if QK_K == 256
kernel void kernel_mul_mat_q3_K_f32( kernel void kernel_mul_mv_q3_K_f32(
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
@ -1486,7 +1561,7 @@ kernel void kernel_mul_mat_q3_K_f32(
} }
} }
#else #else
kernel void kernel_mul_mat_q3_K_f32( kernel void kernel_mul_mv_q3_K_f32(
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
@ -1557,7 +1632,7 @@ kernel void kernel_mul_mat_q3_K_f32(
#endif #endif
#if QK_K == 256 #if QK_K == 256
kernel void kernel_mul_mat_q4_K_f32( kernel void kernel_mul_mv_q4_K_f32(
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
@ -1663,7 +1738,7 @@ kernel void kernel_mul_mat_q4_K_f32(
} }
} }
#else #else
kernel void kernel_mul_mat_q4_K_f32( kernel void kernel_mul_mv_q4_K_f32(
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
@ -1752,7 +1827,7 @@ kernel void kernel_mul_mat_q4_K_f32(
} }
#endif #endif
kernel void kernel_mul_mat_q5_K_f32( kernel void kernel_mul_mv_q5_K_f32(
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
@ -1925,7 +2000,7 @@ kernel void kernel_mul_mat_q5_K_f32(
} }
kernel void kernel_mul_mat_q6_K_f32( kernel void kernel_mul_mv_q6_K_f32(
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
@ -2263,7 +2338,7 @@ kernel void kernel_get_rows(
} }
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
#define BLOCK_SIZE_K 32 #define BLOCK_SIZE_K 32
#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
@ -2300,9 +2375,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
const uint r0 = tgpig.y; const uint r0 = tgpig.y;
const uint r1 = tgpig.x; const uint r1 = tgpig.x;
const uint im = tgpig.z; const uint im = tgpig.z;
// if this block is of 64x32 shape or smaller // if this block is of 64x32 shape or smaller
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
// a thread shouldn't load data outside of the matrix // a thread shouldn't load data outside of the matrix
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
@ -2330,22 +2407,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
half4x4 temp_a; half4x4 temp_a;
dequantize_func(x, il, temp_a); dequantize_func(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
#pragma unroll(16) #pragma unroll(16)
for (int i = 0; i < 16; i++) { for (int i = 0; i < 16; i++) {
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+ 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \ + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
} }
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
= *((device float2x4 *)y); *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
il = (il + 2 < nl) ? il + 2 : il % 2; il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2+nl-1)/nl : x; x = (il < 2) ? x + (2+nl-1)/nl : x;
y += BLOCK_SIZE_K; y += BLOCK_SIZE_K;
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
// load matrices from threadgroup memory and conduct outer products // load matrices from threadgroup memory and conduct outer products
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
#pragma unroll(4) #pragma unroll(4)
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
#pragma unroll(4) #pragma unroll(4)
@ -2360,6 +2441,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
#pragma unroll(8) #pragma unroll(8)
for (int i = 0; i < 8; i++){ for (int i = 0; i < 8; i++){
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
@ -2368,7 +2450,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
} }
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \ device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
@ -2383,7 +2465,8 @@ kernel void kernel_mul_mm(device const uchar * src0,
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
if (sgitg == 0) { if (sgitg == 0) {
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {

View File

@ -19,6 +19,7 @@ module.exports = {
new webpack.DefinePlugin({ new webpack.DefinePlugin({
PLUGIN_NAME: JSON.stringify(packageJson.name), PLUGIN_NAME: JSON.stringify(packageJson.name),
MODULE_PATH: JSON.stringify(`${packageJson.name}/${packageJson.module}`), MODULE_PATH: JSON.stringify(`${packageJson.name}/${packageJson.module}`),
INFERENCE_URL: JSON.stringify("http://127.0.0.1:3928/inferences/llamacpp/chat_completion"),
}), }),
], ],
output: { output: {