diff --git a/pkg/db/media_server.go b/pkg/db/media_server.go index e6cb67f..f2ca72e 100644 --- a/pkg/db/media_server.go +++ b/pkg/db/media_server.go @@ -58,6 +58,19 @@ func NewMediaServerDB(dbPath string) (*MediaServerDB, error) { return &MediaServerDB{db: db}, nil } +// GetMediaServerByNameAndIP 根据名称和IP查询媒体服务器 +func (m *MediaServerDB) GetMediaServerByNameAndIP(name, ip string) (*models.MediaServerResponse, error) { + var ms models.MediaServerResponse + err := m.db.QueryRow(` + SELECT id, name, type, ip, port, username, password, secret, is_default, created_at + FROM media_servers WHERE name = ? AND ip = ? + `, name, ip).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt) + if err != nil { + return nil, err + } + return &ms, nil +} + func (m *MediaServerDB) AddMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error { _, err := m.db.Exec(` INSERT INTO media_servers (name, type, ip, port, username, password, secret, is_default) @@ -66,6 +79,24 @@ func (m *MediaServerDB) AddMediaServer(name, serverType, ip string, port int, us return err } +// AddOrUpdateMediaServer 添加或更新媒体服务器(如果已存在则更新) +func (m *MediaServerDB) AddOrUpdateMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error { + // 检查是否已存在 + existing, err := m.GetMediaServerByNameAndIP(name, ip) + if err == nil && existing != nil { + // 已存在,更新记录 + _, err = m.db.Exec(` + UPDATE media_servers + SET type = ?, port = ?, username = ?, password = ?, secret = ?, is_default = ? + WHERE name = ? AND ip = ? + `, serverType, port, username, password, secret, isDefault, name, ip) + return err + } + + // 不存在,插入新记录 + return m.AddMediaServer(name, serverType, ip, port, username, password, secret, isDefault) +} + func (m *MediaServerDB) DeleteMediaServer(id int) error { _, err := m.db.Exec("DELETE FROM media_servers WHERE id = ?", id) return err diff --git a/pkg/db/media_server_test.go b/pkg/db/media_server_test.go new file mode 100644 index 0000000..1d9fcdf --- /dev/null +++ b/pkg/db/media_server_test.go @@ -0,0 +1,107 @@ +package db + +import ( + "os" + "testing" +) + +func TestAddOrUpdateMediaServer(t *testing.T) { + // 创建临时数据库文件 + dbPath := "./test_media_servers.db" + defer os.Remove(dbPath) + + // 创建数据库实例 + db, err := NewMediaServerDB(dbPath) + if err != nil { + t.Fatalf("Failed to create database: %v", err) + } + defer db.Close() + + // 测试第一次添加 + err = db.AddOrUpdateMediaServer("Default", "SRS", "192.168.1.100", 1985, "", "", "", 1) + if err != nil { + t.Fatalf("Failed to add media server: %v", err) + } + + // 验证添加成功 + servers, err := db.ListMediaServers() + if err != nil { + t.Fatalf("Failed to list media servers: %v", err) + } + if len(servers) != 1 { + t.Fatalf("Expected 1 server, got %d", len(servers)) + } + if servers[0].Name != "Default" || servers[0].IP != "192.168.1.100" { + t.Fatalf("Server data mismatch: %+v", servers[0]) + } + + // 测试重复添加(应该更新而不是插入新记录) + err = db.AddOrUpdateMediaServer("Default", "SRS", "192.168.1.100", 1985, "admin", "password", "secret", 1) + if err != nil { + t.Fatalf("Failed to update media server: %v", err) + } + + // 验证没有重复记录 + servers, err = db.ListMediaServers() + if err != nil { + t.Fatalf("Failed to list media servers: %v", err) + } + if len(servers) != 1 { + t.Fatalf("Expected 1 server after update, got %d", len(servers)) + } + if servers[0].Username != "admin" || servers[0].Password != "password" { + t.Fatalf("Server update failed: %+v", servers[0]) + } + + // 测试多次调用(模拟容器重启) + for i := 0; i < 5; i++ { + err = db.AddOrUpdateMediaServer("Default", "SRS", "192.168.1.100", 1985, "", "", "", 1) + if err != nil { + t.Fatalf("Failed on iteration %d: %v", i, err) + } + } + + // 验证仍然只有一条记录 + servers, err = db.ListMediaServers() + if err != nil { + t.Fatalf("Failed to list media servers: %v", err) + } + if len(servers) != 1 { + t.Fatalf("Expected 1 server after multiple restarts, got %d", len(servers)) + } + + t.Log("Test passed: No duplicate servers created on restart") +} + +func TestAddMediaServerDuplicates(t *testing.T) { + // 创建临时数据库文件 + dbPath := "./test_media_servers_dup.db" + defer os.Remove(dbPath) + + // 创建数据库实例 + db, err := NewMediaServerDB(dbPath) + if err != nil { + t.Fatalf("Failed to create database: %v", err) + } + defer db.Close() + + // 使用旧的 AddMediaServer 方法测试重复添加问题 + for i := 0; i < 3; i++ { + err = db.AddMediaServer("Default", "SRS", "192.168.1.100", 1985, "", "", "", 1) + if err != nil { + t.Fatalf("Failed to add media server: %v", err) + } + } + + // 验证会产生重复记录 + servers, err := db.ListMediaServers() + if err != nil { + t.Fatalf("Failed to list media servers: %v", err) + } + if len(servers) != 3 { + t.Fatalf("Expected 3 duplicate servers with old method, got %d", len(servers)) + } + + t.Log("Test confirmed: Old AddMediaServer method creates duplicates") +} + diff --git a/pkg/service/uas.go b/pkg/service/uas.go index 9c79485..ace8f09 100644 --- a/pkg/service/uas.go +++ b/pkg/service/uas.go @@ -86,7 +86,7 @@ func (s *UAS) startSipServer(agent *sipgo.UserAgent, ctx context.Context, r0 int candidate := os.Getenv("CANDIDATE") if candidate != "" { - MediaDB.AddMediaServer("Default", "SRS", candidate, 1985, "", "", "", 1) + MediaDB.AddOrUpdateMediaServer("Default", "SRS", candidate, 1985, "", "", "", 1) } return nil }