我使用 vmadot 指令实现了一个简单的矩阵乘算法,用于验证功能。运行时出现了段错误(segmentation fault),求大佬解答。另外,是否有编程手册及程序示例?
#include <stdio.h>
#include <stdint.h>
/*
 * Assembly routine (defined outside this C code) that multiplies one A tile
 * by one B tile with the vmadot instruction and writes the result to C.
 * All three arguments are base addresses of contiguous row-major tiles.
 * Declared here so the call in main() is not an implicit declaration.
 */
void call_vmadot(const int8_t *a_base, const int8_t *b_base, int32_t *c_base);

/*
 * Test driver: fills A (M x K) and B (N x K) with their row index, runs the
 * vmadot routine, and checks the result against a scalar reference
 * C_baseline[i][j] = sum_k A[i][k] * B[j][k].
 *
 * Returns 0 when the vmadot result matches the scalar reference, 1 otherwise.
 */
int main(void) {
    /* Compile-time tile dimensions, so the arrays are ordinary arrays, not VLAs. */
    enum { M = 4, N = 4, K = 8 };

    int8_t A[M][K];
    int8_t B[N][K];
    int32_t C[M][N];
    int32_t C_baseline[M][N];

    /* init A, B, C */
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < K; j++) {
            A[i][j] = (int8_t)i;
        }
    }
    for (int i = 0; i < N; i++) {
        for (int j = 0; j < K; j++) {
            B[i][j] = (int8_t)i;
        }
    }
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            C[i][j] = 0;
            C_baseline[i][j] = 0;
        }
    }

    /*
     * Scalar reference. B is stored as N rows of K elements, so row j of B is
     * dotted with row i of A.
     * NOTE(review): this assumes vmadot treats the second operand as an N x K
     * tile (i.e. computes A * B^T) - confirm against the instruction manual.
     */
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            for (int k = 0; k < K; k++) {
                C_baseline[i][j] += (int32_t)A[i][k] * (int32_t)B[j][k];
            }
        }
    }

    call_vmadot(&A[0][0], &B[0][0], &C[0][0]);

    /* after vmadot */
    printf("after vmadot C:\n");
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            printf("%d ", (int)C[i][j]);
        }
        printf("\n");
    }
    printf("\n");

    /* Print the reference and compare element by element. */
    int mismatches = 0;
    printf("baseline C:\n");
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            printf("%d ", (int)C_baseline[i][j]);
            if (C[i][j] != C_baseline[i][j]) {
                mismatches++;
            }
        }
        printf("\n");
    }
    printf("\n");
    printf("%s (%d mismatches)\n", mismatches == 0 ? "PASS" : "FAIL", mismatches);

    return mismatches == 0 ? 0 : 1;
}
# RISC-V vmadot matrix MAC unit, tile shape [4,4,8].
#
# void call_vmadot(a_base, b_base, c_base)
#   a0 = a_base : address of the int8 A tile
#   a1 = b_base : address of the int8 B tile
#   a2 = c_base : address of the int32 C tile (read as accumulator, then written)
#
# The arguments arrive in a0-a2 as the addresses themselves, so they are used
# directly. Do NOT load through them first: `lw a0, 0(a0)` would replace the
# pointer with the first four bytes of matrix data, and the following vector
# load would then fault on that bogus address.
#
# This is a leaf function that only touches t0 and v0-v3, so nothing needs to
# be spilled to the stack. (If ra ever has to be saved on RV64, use sd/ld;
# sw/lw only save the low 32 bits of a 64-bit register.)
#
# Comments use '#': in GNU as for RISC-V, ';' separates statements and does
# not start a comment.
#
# NOTE(review): the vsetvli calls request VLMAX (rs1 = x0). The tile sizes
# only line up if VLEN = 256 bits (32 x e8 in one register for A and B,
# 16 x e32 in a two-register group for C) - confirm for the target core.
    .text
    .align 4
    .global call_vmadot
call_vmadot:
    # Load the current C tile into v2-v3 so the accumulator is well defined.
    # NOTE(review): vmadot is assumed to accumulate (vd += vs1 * vs2); with a
    # zeroed C this yields the plain product either way - confirm in the spec.
    vsetvli t0, x0, e32, m2, ta, ma
    vle32.v v2, (a2)
    # Load the int8 A and B tiles.
    vsetvli t0, x0, e8, m1, ta, ma
    vle8.v v0, (a0)
    vle8.v v1, (a1)
    # v2-v3 <- matrix multiply-accumulate of v0 (A) and v1 (B).
    vmadot v2, v0, v1
    # Store the int32 result tile back to C.
    vsetvli t0, x0, e32, m2, ta, ma
    vse32.v v2, (a2)
    ret