fix issues/26
This commit is contained in:
@ -58,6 +58,19 @@ func NewMediaServerDB(dbPath string) (*MediaServerDB, error) {
|
||||
return &MediaServerDB{db: db}, nil
|
||||
}
|
||||
|
||||
// GetMediaServerByNameAndIP 根据名称和IP查询媒体服务器
|
||||
func (m *MediaServerDB) GetMediaServerByNameAndIP(name, ip string) (*models.MediaServerResponse, error) {
|
||||
var ms models.MediaServerResponse
|
||||
err := m.db.QueryRow(`
|
||||
SELECT id, name, type, ip, port, username, password, secret, is_default, created_at
|
||||
FROM media_servers WHERE name = ? AND ip = ?
|
||||
`, name, ip).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ms, nil
|
||||
}
|
||||
|
||||
func (m *MediaServerDB) AddMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error {
|
||||
_, err := m.db.Exec(`
|
||||
INSERT INTO media_servers (name, type, ip, port, username, password, secret, is_default)
|
||||
@ -66,6 +79,24 @@ func (m *MediaServerDB) AddMediaServer(name, serverType, ip string, port int, us
|
||||
return err
|
||||
}
|
||||
|
||||
// AddOrUpdateMediaServer 添加或更新媒体服务器(如果已存在则更新)
|
||||
func (m *MediaServerDB) AddOrUpdateMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error {
|
||||
// 检查是否已存在
|
||||
existing, err := m.GetMediaServerByNameAndIP(name, ip)
|
||||
if err == nil && existing != nil {
|
||||
// 已存在,更新记录
|
||||
_, err = m.db.Exec(`
|
||||
UPDATE media_servers
|
||||
SET type = ?, port = ?, username = ?, password = ?, secret = ?, is_default = ?
|
||||
WHERE name = ? AND ip = ?
|
||||
`, serverType, port, username, password, secret, isDefault, name, ip)
|
||||
return err
|
||||
}
|
||||
|
||||
// 不存在,插入新记录
|
||||
return m.AddMediaServer(name, serverType, ip, port, username, password, secret, isDefault)
|
||||
}
|
||||
|
||||
func (m *MediaServerDB) DeleteMediaServer(id int) error {
|
||||
_, err := m.db.Exec("DELETE FROM media_servers WHERE id = ?", id)
|
||||
return err
|
||||
|
||||
107
pkg/db/media_server_test.go
Normal file
107
pkg/db/media_server_test.go
Normal file
@ -0,0 +1,107 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAddOrUpdateMediaServer(t *testing.T) {
|
||||
// 创建临时数据库文件
|
||||
dbPath := "./test_media_servers.db"
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
// 创建数据库实例
|
||||
db, err := NewMediaServerDB(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// 测试第一次添加
|
||||
err = db.AddOrUpdateMediaServer("Default", "SRS", "192.168.1.100", 1985, "", "", "", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add media server: %v", err)
|
||||
}
|
||||
|
||||
// 验证添加成功
|
||||
servers, err := db.ListMediaServers()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list media servers: %v", err)
|
||||
}
|
||||
if len(servers) != 1 {
|
||||
t.Fatalf("Expected 1 server, got %d", len(servers))
|
||||
}
|
||||
if servers[0].Name != "Default" || servers[0].IP != "192.168.1.100" {
|
||||
t.Fatalf("Server data mismatch: %+v", servers[0])
|
||||
}
|
||||
|
||||
// 测试重复添加(应该更新而不是插入新记录)
|
||||
err = db.AddOrUpdateMediaServer("Default", "SRS", "192.168.1.100", 1985, "admin", "password", "secret", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update media server: %v", err)
|
||||
}
|
||||
|
||||
// 验证没有重复记录
|
||||
servers, err = db.ListMediaServers()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list media servers: %v", err)
|
||||
}
|
||||
if len(servers) != 1 {
|
||||
t.Fatalf("Expected 1 server after update, got %d", len(servers))
|
||||
}
|
||||
if servers[0].Username != "admin" || servers[0].Password != "password" {
|
||||
t.Fatalf("Server update failed: %+v", servers[0])
|
||||
}
|
||||
|
||||
// 测试多次调用(模拟容器重启)
|
||||
for i := 0; i < 5; i++ {
|
||||
err = db.AddOrUpdateMediaServer("Default", "SRS", "192.168.1.100", 1985, "", "", "", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed on iteration %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证仍然只有一条记录
|
||||
servers, err = db.ListMediaServers()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list media servers: %v", err)
|
||||
}
|
||||
if len(servers) != 1 {
|
||||
t.Fatalf("Expected 1 server after multiple restarts, got %d", len(servers))
|
||||
}
|
||||
|
||||
t.Log("Test passed: No duplicate servers created on restart")
|
||||
}
|
||||
|
||||
func TestAddMediaServerDuplicates(t *testing.T) {
|
||||
// 创建临时数据库文件
|
||||
dbPath := "./test_media_servers_dup.db"
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
// 创建数据库实例
|
||||
db, err := NewMediaServerDB(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// 使用旧的 AddMediaServer 方法测试重复添加问题
|
||||
for i := 0; i < 3; i++ {
|
||||
err = db.AddMediaServer("Default", "SRS", "192.168.1.100", 1985, "", "", "", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add media server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证会产生重复记录
|
||||
servers, err := db.ListMediaServers()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list media servers: %v", err)
|
||||
}
|
||||
if len(servers) != 3 {
|
||||
t.Fatalf("Expected 3 duplicate servers with old method, got %d", len(servers))
|
||||
}
|
||||
|
||||
t.Log("Test confirmed: Old AddMediaServer method creates duplicates")
|
||||
}
|
||||
|
||||
@ -86,7 +86,7 @@ func (s *UAS) startSipServer(agent *sipgo.UserAgent, ctx context.Context, r0 int
|
||||
|
||||
candidate := os.Getenv("CANDIDATE")
|
||||
if candidate != "" {
|
||||
MediaDB.AddMediaServer("Default", "SRS", candidate, 1985, "", "", "", 1)
|
||||
MediaDB.AddOrUpdateMediaServer("Default", "SRS", candidate, 1985, "", "", "", 1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user