From 156f07644d535493bf174a3514eace3cc52b231f Mon Sep 17 00:00:00 2001
From: "haibo.chen" <495810242@qq.com>
Date: Wed, 15 Oct 2025 10:05:52 +0800
Subject: [PATCH] gofmt
---
pkg/config/config_test.go | 3 +-
pkg/db/media_server.go | 304 ++++++++++++-------------
pkg/media/zlm.go | 118 +++++-----
pkg/models/gb28181.go | 212 +++++++++---------
pkg/models/types.go | 160 +++++++-------
pkg/models/types_test.go | 1 -
pkg/service/auth.go | 184 ++++++++--------
pkg/service/auth_test.go | 1 -
pkg/service/cascade.go | 34 +--
pkg/service/inbound.go | 342 ++++++++++++++---------------
pkg/service/ptz.go | 162 +++++++-------
pkg/service/ptz_test.go | 3 +-
pkg/service/stack/request.go | 136 ++++++------
pkg/service/stack/response.go | 82 +++----
pkg/service/uac.go | 244 ++++++++++-----------
pkg/utils/logger.go | 402 +++++++++++++++++-----------------
pkg/utils/utils_test.go | 1 -
tools/main.go | 60 ++---
18 files changed, 1222 insertions(+), 1227 deletions(-)
diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go
index 546a9a6..c2cbbdb 100644
--- a/pkg/config/config_test.go
+++ b/pkg/config/config_test.go
@@ -133,7 +133,7 @@ func TestLoadConfigInvalid(t *testing.T) {
func TestGetLocalIP(t *testing.T) {
ip, err := GetLocalIP()
-
+
// 在某些环境下可能没有网络接口,所以允许返回错误
if err != nil {
t.Logf("GetLocalIP returned error (may be expected in some environments): %v", err)
@@ -152,4 +152,3 @@ func TestGetLocalIP(t *testing.T) {
t.Logf("Local IP: %s", ip)
}
-
diff --git a/pkg/db/media_server.go b/pkg/db/media_server.go
index f2ca72e..5fa1c9d 100644
--- a/pkg/db/media_server.go
+++ b/pkg/db/media_server.go
@@ -1,152 +1,152 @@
-package db
-
-import (
- "database/sql"
- "sync"
-
- "github.com/ossrs/srs-sip/pkg/models"
- _ "modernc.org/sqlite"
-)
-
-var (
- instance *MediaServerDB
- once sync.Once
-)
-
-type MediaServerDB struct {
- models.MediaServerResponse
- db *sql.DB
-}
-
-// GetInstance 返回 MediaServerDB 的单例实例
-func GetInstance(dbPath string) (*MediaServerDB, error) {
- var err error
- once.Do(func() {
- instance, err = NewMediaServerDB(dbPath)
- })
- if err != nil {
- return nil, err
- }
- return instance, nil
-}
-
-func NewMediaServerDB(dbPath string) (*MediaServerDB, error) {
- db, err := sql.Open("sqlite", dbPath)
- if err != nil {
- return nil, err
- }
-
- // 创建媒体服务器表
- _, err = db.Exec(`
- CREATE TABLE IF NOT EXISTS media_servers (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- type TEXT NOT NULL,
- name TEXT NOT NULL,
- ip TEXT NOT NULL,
- port INTEGER NOT NULL,
- username TEXT,
- password TEXT,
- secret TEXT,
- is_default INTEGER NOT NULL DEFAULT 0,
- created_at DATETIME DEFAULT CURRENT_TIMESTAMP
- )
- `)
- if err != nil {
- return nil, err
- }
-
- return &MediaServerDB{db: db}, nil
-}
-
-// GetMediaServerByNameAndIP 根据名称和IP查询媒体服务器
-func (m *MediaServerDB) GetMediaServerByNameAndIP(name, ip string) (*models.MediaServerResponse, error) {
- var ms models.MediaServerResponse
- err := m.db.QueryRow(`
- SELECT id, name, type, ip, port, username, password, secret, is_default, created_at
- FROM media_servers WHERE name = ? AND ip = ?
- `, name, ip).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt)
- if err != nil {
- return nil, err
- }
- return &ms, nil
-}
-
-func (m *MediaServerDB) AddMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error {
- _, err := m.db.Exec(`
- INSERT INTO media_servers (name, type, ip, port, username, password, secret, is_default)
- VALUES (?, ?, ?, ?, ?, ?, ?, ?)
- `, name, serverType, ip, port, username, password, secret, isDefault)
- return err
-}
-
-// AddOrUpdateMediaServer 添加或更新媒体服务器(如果已存在则更新)
-func (m *MediaServerDB) AddOrUpdateMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error {
- // 检查是否已存在
- existing, err := m.GetMediaServerByNameAndIP(name, ip)
- if err == nil && existing != nil {
- // 已存在,更新记录
- _, err = m.db.Exec(`
- UPDATE media_servers
- SET type = ?, port = ?, username = ?, password = ?, secret = ?, is_default = ?
- WHERE name = ? AND ip = ?
- `, serverType, port, username, password, secret, isDefault, name, ip)
- return err
- }
-
- // 不存在,插入新记录
- return m.AddMediaServer(name, serverType, ip, port, username, password, secret, isDefault)
-}
-
-func (m *MediaServerDB) DeleteMediaServer(id int) error {
- _, err := m.db.Exec("DELETE FROM media_servers WHERE id = ?", id)
- return err
-}
-
-func (m *MediaServerDB) GetMediaServer(id int) (*models.MediaServerResponse, error) {
- var ms models.MediaServerResponse
- err := m.db.QueryRow(`
- SELECT id, name, type, ip, port, username, password, secret, is_default, created_at
- FROM media_servers WHERE id = ?
- `, id).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt)
- if err != nil {
- return nil, err
- }
- return &ms, nil
-}
-
-func (m *MediaServerDB) ListMediaServers() ([]models.MediaServerResponse, error) {
- rows, err := m.db.Query(`
- SELECT id, name, type, ip, port, username, password, secret, is_default, created_at
- FROM media_servers ORDER BY created_at DESC
- `)
- if err != nil {
- return nil, err
- }
- defer rows.Close()
-
- var servers []models.MediaServerResponse
- for rows.Next() {
- var ms models.MediaServerResponse
- err := rows.Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt)
- if err != nil {
- return nil, err
- }
- servers = append(servers, ms)
- }
- return servers, nil
-}
-
-func (m *MediaServerDB) SetDefaultMediaServer(id int) error {
- // 先将所有服务器设置为非默认
- if _, err := m.db.Exec("UPDATE media_servers SET is_default = 0"); err != nil {
- return err
- }
-
- // 将指定ID的服务器设置为默认
- _, err := m.db.Exec("UPDATE media_servers SET is_default = 1 WHERE id = ?", id)
- return err
-}
-
-func (m *MediaServerDB) Close() error {
- return m.db.Close()
-}
+package db
+
+import (
+ "database/sql"
+ "sync"
+
+ "github.com/ossrs/srs-sip/pkg/models"
+ _ "modernc.org/sqlite"
+)
+
+var (
+ instance *MediaServerDB
+ once sync.Once
+)
+
+type MediaServerDB struct {
+ models.MediaServerResponse
+ db *sql.DB
+}
+
+// GetInstance 返回 MediaServerDB 的单例实例
+func GetInstance(dbPath string) (*MediaServerDB, error) {
+ var err error
+ once.Do(func() {
+ instance, err = NewMediaServerDB(dbPath)
+ })
+ if err != nil {
+ return nil, err
+ }
+ return instance, nil
+}
+
+func NewMediaServerDB(dbPath string) (*MediaServerDB, error) {
+ db, err := sql.Open("sqlite", dbPath)
+ if err != nil {
+ return nil, err
+ }
+
+ // 创建媒体服务器表
+ _, err = db.Exec(`
+ CREATE TABLE IF NOT EXISTS media_servers (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ type TEXT NOT NULL,
+ name TEXT NOT NULL,
+ ip TEXT NOT NULL,
+ port INTEGER NOT NULL,
+ username TEXT,
+ password TEXT,
+ secret TEXT,
+ is_default INTEGER NOT NULL DEFAULT 0,
+ created_at DATETIME DEFAULT CURRENT_TIMESTAMP
+ )
+ `)
+ if err != nil {
+ return nil, err
+ }
+
+ return &MediaServerDB{db: db}, nil
+}
+
+// GetMediaServerByNameAndIP 根据名称和IP查询媒体服务器
+func (m *MediaServerDB) GetMediaServerByNameAndIP(name, ip string) (*models.MediaServerResponse, error) {
+ var ms models.MediaServerResponse
+ err := m.db.QueryRow(`
+ SELECT id, name, type, ip, port, username, password, secret, is_default, created_at
+ FROM media_servers WHERE name = ? AND ip = ?
+ `, name, ip).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt)
+ if err != nil {
+ return nil, err
+ }
+ return &ms, nil
+}
+
+func (m *MediaServerDB) AddMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error {
+ _, err := m.db.Exec(`
+ INSERT INTO media_servers (name, type, ip, port, username, password, secret, is_default)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+ `, name, serverType, ip, port, username, password, secret, isDefault)
+ return err
+}
+
+// AddOrUpdateMediaServer 添加或更新媒体服务器(如果已存在则更新)
+func (m *MediaServerDB) AddOrUpdateMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error {
+ // 检查是否已存在
+ existing, err := m.GetMediaServerByNameAndIP(name, ip)
+ if err == nil && existing != nil {
+ // 已存在,更新记录
+ _, err = m.db.Exec(`
+ UPDATE media_servers
+ SET type = ?, port = ?, username = ?, password = ?, secret = ?, is_default = ?
+ WHERE name = ? AND ip = ?
+ `, serverType, port, username, password, secret, isDefault, name, ip)
+ return err
+ }
+
+ // 不存在,插入新记录
+ return m.AddMediaServer(name, serverType, ip, port, username, password, secret, isDefault)
+}
+
+func (m *MediaServerDB) DeleteMediaServer(id int) error {
+ _, err := m.db.Exec("DELETE FROM media_servers WHERE id = ?", id)
+ return err
+}
+
+func (m *MediaServerDB) GetMediaServer(id int) (*models.MediaServerResponse, error) {
+ var ms models.MediaServerResponse
+ err := m.db.QueryRow(`
+ SELECT id, name, type, ip, port, username, password, secret, is_default, created_at
+ FROM media_servers WHERE id = ?
+ `, id).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt)
+ if err != nil {
+ return nil, err
+ }
+ return &ms, nil
+}
+
+func (m *MediaServerDB) ListMediaServers() ([]models.MediaServerResponse, error) {
+ rows, err := m.db.Query(`
+ SELECT id, name, type, ip, port, username, password, secret, is_default, created_at
+ FROM media_servers ORDER BY created_at DESC
+ `)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var servers []models.MediaServerResponse
+ for rows.Next() {
+ var ms models.MediaServerResponse
+ err := rows.Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt)
+ if err != nil {
+ return nil, err
+ }
+ servers = append(servers, ms)
+ }
+ return servers, nil
+}
+
+func (m *MediaServerDB) SetDefaultMediaServer(id int) error {
+ // 先将所有服务器设置为非默认
+ if _, err := m.db.Exec("UPDATE media_servers SET is_default = 0"); err != nil {
+ return err
+ }
+
+ // 将指定ID的服务器设置为默认
+ _, err := m.db.Exec("UPDATE media_servers SET is_default = 1 WHERE id = ?", id)
+ return err
+}
+
+func (m *MediaServerDB) Close() error {
+ return m.db.Close()
+}
diff --git a/pkg/media/zlm.go b/pkg/media/zlm.go
index ebc6619..a2abc98 100644
--- a/pkg/media/zlm.go
+++ b/pkg/media/zlm.go
@@ -1,59 +1,59 @@
-package media
-
-import (
- "context"
-
- "github.com/ossrs/go-oryx-lib/errors"
-)
-
-type Zlm struct {
- Ctx context.Context
- Schema string // The schema of ZLM, eg: http
- Addr string // The address of ZLM, eg: localhost:8085
- Secret string // The secret of ZLM, eg: ZLMediaKit_secret
-}
-
-// /index/api/openRtpServer
-// secret={{ZLMediaKit_secret}}&port=0&enable_tcp=1&stream_id=test2
-func (z *Zlm) Publish(id, ssrc string) (int, error) {
-
- res := struct {
- Code int `json:"code"`
- Port int `json:"port"`
- }{}
-
- if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/openRtpServer?secret="+z.Secret+"&port=0&enable_tcp=1&stream_id="+id+"&ssrc="+ssrc, nil, &res); err != nil {
- return 0, errors.Wrapf(err, "gb/v1/publish")
- }
- return res.Port, nil
-}
-
-// /index/api/closeRtpServer
-func (z *Zlm) Unpublish(id string) error {
- res := struct {
- Code int `json:"code"`
- }{}
- if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/closeRtpServer?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil {
- return errors.Wrapf(err, "gb/v1/publish")
- }
- return nil
-}
-
-// /index/api/getMediaList
-func (z *Zlm) GetStreamStatus(id string) (bool, error) {
- res := struct {
- Code int `json:"code"`
- }{}
- if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/getMediaList?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil {
- return false, errors.Wrapf(err, "gb/v1/publish")
- }
- return res.Code == 0, nil
-}
-
-func (z *Zlm) GetAddr() string {
- return z.Addr
-}
-
-func (z *Zlm) GetWebRTCAddr(id string) string {
- return "http://" + z.Addr + "/index/api/webrtc?app=rtp&stream=" + id + "&type=play"
-}
+package media
+
+import (
+ "context"
+
+ "github.com/ossrs/go-oryx-lib/errors"
+)
+
+type Zlm struct {
+ Ctx context.Context
+ Schema string // The schema of ZLM, eg: http
+ Addr string // The address of ZLM, eg: localhost:8085
+ Secret string // The secret of ZLM, eg: ZLMediaKit_secret
+}
+
+// /index/api/openRtpServer
+// secret={{ZLMediaKit_secret}}&port=0&enable_tcp=1&stream_id=test2
+func (z *Zlm) Publish(id, ssrc string) (int, error) {
+
+ res := struct {
+ Code int `json:"code"`
+ Port int `json:"port"`
+ }{}
+
+ if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/openRtpServer?secret="+z.Secret+"&port=0&enable_tcp=1&stream_id="+id+"&ssrc="+ssrc, nil, &res); err != nil {
+ return 0, errors.Wrapf(err, "gb/v1/publish")
+ }
+ return res.Port, nil
+}
+
+// /index/api/closeRtpServer
+func (z *Zlm) Unpublish(id string) error {
+ res := struct {
+ Code int `json:"code"`
+ }{}
+ if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/closeRtpServer?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil {
+ return errors.Wrapf(err, "gb/v1/publish")
+ }
+ return nil
+}
+
+// /index/api/getMediaList
+func (z *Zlm) GetStreamStatus(id string) (bool, error) {
+ res := struct {
+ Code int `json:"code"`
+ }{}
+ if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/getMediaList?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil {
+ return false, errors.Wrapf(err, "gb/v1/publish")
+ }
+ return res.Code == 0, nil
+}
+
+func (z *Zlm) GetAddr() string {
+ return z.Addr
+}
+
+func (z *Zlm) GetWebRTCAddr(id string) string {
+ return "http://" + z.Addr + "/index/api/webrtc?app=rtp&stream=" + id + "&type=play"
+}
diff --git a/pkg/models/gb28181.go b/pkg/models/gb28181.go
index e7c52e4..b8c79be 100644
--- a/pkg/models/gb28181.go
+++ b/pkg/models/gb28181.go
@@ -1,106 +1,106 @@
-package models
-
-import "encoding/xml"
-
-type Record struct {
- DeviceID string `xml:"DeviceID" json:"device_id"`
- Name string `xml:"Name" json:"name"`
- FilePath string `xml:"FilePath" json:"file_path"`
- Address string `xml:"Address" json:"address"`
- StartTime string `xml:"StartTime" json:"start_time"`
- EndTime string `xml:"EndTime" json:"end_time"`
- Secrecy int `xml:"Secrecy" json:"secrecy"`
- Type string `xml:"Type" json:"type"`
-}
-
-// Example XML structure for channel info:
-//
-// -
-// 34020000001320000002
-// 209
-// UNIVIEW
-// HIC6622-IR@X33-VF
-// IPC-B2202.7.11.230222
-// CivilCode
-// Address
-// 1
-// 75015310072008100002
-// 0
-// 1
-// 0
-// ON
-// 0.0000000
-// 0.0000000
-//
-// 1
-// 6/4/2
-// 0
-//
-//
-
-type ChannelInfo struct {
- DeviceID string `json:"device_id"`
- ParentID string `json:"parent_id"`
- Name string `json:"name"`
- Manufacturer string `json:"manufacturer"`
- Model string `json:"model"`
- Owner string `json:"owner"`
- CivilCode string `json:"civil_code"`
- Address string `json:"address"`
- Port int `json:"port"`
- Parental int `json:"parental"`
- SafetyWay int `json:"safety_way"`
- RegisterWay int `json:"register_way"`
- Secrecy int `json:"secrecy"`
- IPAddress string `json:"ip_address"`
- Status ChannelStatus `json:"status"`
- Longitude float64 `json:"longitude"`
- Latitude float64 `json:"latitude"`
- Info struct {
- PTZType int `json:"ptz_type"`
- Resolution string `json:"resolution"`
- DownloadSpeed string `json:"download_speed"` // Speed levels: 1/2/4/8
- } `json:"info"`
-
- // Custom fields
- Ssrc string `json:"ssrc"`
-}
-
-type ChannelStatus string
-
-// BasicParam
-//
-//
-//
-//
-//
-//
-//
-//
-//
-//
-//
-//
-//
-//
-type BasicParam struct {
- Name string `xml:"Name"`
- Expiration int `xml:"Expiration"`
- HeartBeatInterval int `xml:"HeartBeatInterval"`
- HeartBeatCount int `xml:"HeartBeatCount"`
-}
-
-type XmlMessageInfo struct {
- XMLName xml.Name
- CmdType string
- SN int
- DeviceID string
- DeviceName string
- Manufacturer string
- Model string
- Channel string
- DeviceList []ChannelInfo `xml:"DeviceList>Item"`
- RecordList []*Record `xml:"RecordList>Item"`
- BasicParam BasicParam `xml:"BasicParam"`
- SumNum int
-}
+package models
+
+import "encoding/xml"
+
+type Record struct {
+ DeviceID string `xml:"DeviceID" json:"device_id"`
+ Name string `xml:"Name" json:"name"`
+ FilePath string `xml:"FilePath" json:"file_path"`
+ Address string `xml:"Address" json:"address"`
+ StartTime string `xml:"StartTime" json:"start_time"`
+ EndTime string `xml:"EndTime" json:"end_time"`
+ Secrecy int `xml:"Secrecy" json:"secrecy"`
+ Type string `xml:"Type" json:"type"`
+}
+
+// Example XML structure for channel info:
+//
+// -
+// 34020000001320000002
+// 209
+// UNIVIEW
+// HIC6622-IR@X33-VF
+// IPC-B2202.7.11.230222
+// CivilCode
+// Address
+// 1
+// 75015310072008100002
+// 0
+// 1
+// 0
+// ON
+// 0.0000000
+// 0.0000000
+//
+// 1
+// 6/4/2
+// 0
+//
+//
+
+type ChannelInfo struct {
+ DeviceID string `json:"device_id"`
+ ParentID string `json:"parent_id"`
+ Name string `json:"name"`
+ Manufacturer string `json:"manufacturer"`
+ Model string `json:"model"`
+ Owner string `json:"owner"`
+ CivilCode string `json:"civil_code"`
+ Address string `json:"address"`
+ Port int `json:"port"`
+ Parental int `json:"parental"`
+ SafetyWay int `json:"safety_way"`
+ RegisterWay int `json:"register_way"`
+ Secrecy int `json:"secrecy"`
+ IPAddress string `json:"ip_address"`
+ Status ChannelStatus `json:"status"`
+ Longitude float64 `json:"longitude"`
+ Latitude float64 `json:"latitude"`
+ Info struct {
+ PTZType int `json:"ptz_type"`
+ Resolution string `json:"resolution"`
+ DownloadSpeed string `json:"download_speed"` // Speed levels: 1/2/4/8
+ } `json:"info"`
+
+ // Custom fields
+ Ssrc string `json:"ssrc"`
+}
+
+type ChannelStatus string
+
+// BasicParam
+//
+//
+//
+//
+//
+//
+//
+//
+//
+//
+//
+//
+//
+//
+type BasicParam struct {
+ Name string `xml:"Name"`
+ Expiration int `xml:"Expiration"`
+ HeartBeatInterval int `xml:"HeartBeatInterval"`
+ HeartBeatCount int `xml:"HeartBeatCount"`
+}
+
+type XmlMessageInfo struct {
+ XMLName xml.Name
+ CmdType string
+ SN int
+ DeviceID string
+ DeviceName string
+ Manufacturer string
+ Model string
+ Channel string
+ DeviceList []ChannelInfo `xml:"DeviceList>Item"`
+ RecordList []*Record `xml:"RecordList>Item"`
+ BasicParam BasicParam `xml:"BasicParam"`
+ SumNum int
+}
diff --git a/pkg/models/types.go b/pkg/models/types.go
index 6ab8a60..ba7e821 100644
--- a/pkg/models/types.go
+++ b/pkg/models/types.go
@@ -1,80 +1,80 @@
-package models
-
-type BaseRequest struct {
- DeviceID string `json:"device_id"`
- ChannelID string `json:"channel_id"`
-}
-
-type InviteRequest struct {
- BaseRequest
- MediaServerId int `json:"media_server_id"`
- PlayType int `json:"play_type"` // 0: live, 1: playback, 2: download
- SubStream int `json:"sub_stream"`
- StartTime int64 `json:"start_time"`
- EndTime int64 `json:"end_time"`
-}
-
-type InviteResponse struct {
- ChannelID string `json:"channel_id"`
- URL string `json:"url"`
-}
-
-type SessionRequest struct {
- BaseRequest
- URL string `json:"url"`
-}
-
-type ByeRequest struct {
- SessionRequest
-}
-
-type PauseRequest struct {
- SessionRequest
-}
-
-type ResumeRequest struct {
- SessionRequest
-}
-
-type SpeedRequest struct {
- SessionRequest
- Speed float32 `json:"speed"`
-}
-
-type PTZControlRequest struct {
- BaseRequest
- PTZ string `json:"ptz"`
- Speed string `json:"speed"`
-}
-
-type QueryRecordRequest struct {
- BaseRequest
- StartTime int64 `json:"start_time"`
- EndTime int64 `json:"end_time"`
-}
-
-type MediaServer struct {
- Name string `json:"name"`
- Type string `json:"type"`
- IP string `json:"ip"`
- Port int `json:"port"`
- Username string `json:"username"`
- Password string `json:"password"`
- Secret string `json:"secret"`
- IsDefault int `json:"is_default"`
-}
-
-type MediaServerRequest struct {
- MediaServer
-}
-
-type MediaServerResponse struct {
- MediaServer
- ID int `json:"id"`
- CreatedAt string `json:"created_at"`
-}
-
-type CommonResponse struct {
- Code int `json:"code"`
- Data interface{} `json:"data"`
-}
+package models
+
+type BaseRequest struct {
+ DeviceID string `json:"device_id"`
+ ChannelID string `json:"channel_id"`
+}
+
+type InviteRequest struct {
+ BaseRequest
+ MediaServerId int `json:"media_server_id"`
+ PlayType int `json:"play_type"` // 0: live, 1: playback, 2: download
+ SubStream int `json:"sub_stream"`
+ StartTime int64 `json:"start_time"`
+ EndTime int64 `json:"end_time"`
+}
+
+type InviteResponse struct {
+ ChannelID string `json:"channel_id"`
+ URL string `json:"url"`
+}
+
+type SessionRequest struct {
+ BaseRequest
+ URL string `json:"url"`
+}
+
+type ByeRequest struct {
+ SessionRequest
+}
+
+type PauseRequest struct {
+ SessionRequest
+}
+
+type ResumeRequest struct {
+ SessionRequest
+}
+
+type SpeedRequest struct {
+ SessionRequest
+ Speed float32 `json:"speed"`
+}
+
+type PTZControlRequest struct {
+ BaseRequest
+ PTZ string `json:"ptz"`
+ Speed string `json:"speed"`
+}
+
+type QueryRecordRequest struct {
+ BaseRequest
+ StartTime int64 `json:"start_time"`
+ EndTime int64 `json:"end_time"`
+}
+
+type MediaServer struct {
+ Name string `json:"name"`
+ Type string `json:"type"`
+ IP string `json:"ip"`
+ Port int `json:"port"`
+ Username string `json:"username"`
+ Password string `json:"password"`
+ Secret string `json:"secret"`
+ IsDefault int `json:"is_default"`
+}
+
+type MediaServerRequest struct {
+ MediaServer
+}
+
+type MediaServerResponse struct {
+ MediaServer
+ ID int `json:"id"`
+ CreatedAt string `json:"created_at"`
+}
+
+type CommonResponse struct {
+ Code int `json:"code"`
+ Data interface{} `json:"data"`
+}
diff --git a/pkg/models/types_test.go b/pkg/models/types_test.go
index 7d5f7e5..bd873bb 100644
--- a/pkg/models/types_test.go
+++ b/pkg/models/types_test.go
@@ -335,4 +335,3 @@ func TestCommonResponseWithDifferentDataTypes(t *testing.T) {
})
}
}
-
diff --git a/pkg/service/auth.go b/pkg/service/auth.go
index 8f1538b..06fc879 100644
--- a/pkg/service/auth.go
+++ b/pkg/service/auth.go
@@ -1,92 +1,92 @@
-package service
-
-import (
- "crypto/md5"
- "crypto/rand"
- "encoding/hex"
- "fmt"
- "strings"
-)
-
-// AuthInfo 存储解析后的认证信息
-type AuthInfo struct {
- Username string
- Realm string
- Nonce string
- URI string
- Response string
- Algorithm string
- Method string
-}
-
-// GenerateNonce 生成随机 nonce 字符串
-func GenerateNonce() string {
- b := make([]byte, 16)
- rand.Read(b)
- return fmt.Sprintf("%x", b)
-}
-
-// ParseAuthorization 解析 SIP Authorization 头
-// Authorization: Digest username="34020000001320000001",realm="3402000000",
-// nonce="44010b73623249f6916a6acf7c316b8e",uri="sip:34020000002000000001@3402000000",
-// response="e4ca3fdc5869fa1c544ea7af60014444",algorithm=MD5
-func ParseAuthorization(auth string) *AuthInfo {
- auth = strings.TrimPrefix(auth, "Digest ")
- parts := strings.Split(auth, ",")
- result := &AuthInfo{}
-
- for _, part := range parts {
- part = strings.TrimSpace(part)
- if !strings.Contains(part, "=") {
- continue
- }
-
- kv := strings.SplitN(part, "=", 2)
- key := strings.TrimSpace(kv[0])
- value := strings.Trim(strings.TrimSpace(kv[1]), "\"")
-
- switch key {
- case "username":
- result.Username = value
- case "realm":
- result.Realm = value
- case "nonce":
- result.Nonce = value
- case "uri":
- result.URI = value
- case "response":
- result.Response = value
- case "algorithm":
- result.Algorithm = value
- }
- }
-
- return result
-}
-
-// ValidateAuth 验证 SIP 认证信息
-func ValidateAuth(authInfo *AuthInfo, password string) bool {
- if authInfo == nil {
- return false
- }
-
- // 默认方法为 REGISTER
- method := "REGISTER"
- if authInfo.Method != "" {
- method = authInfo.Method
- }
-
- // 计算 MD5 哈希
- ha1 := md5Hex(authInfo.Username + ":" + authInfo.Realm + ":" + password)
- ha2 := md5Hex(method + ":" + authInfo.URI)
- correctResponse := md5Hex(ha1 + ":" + authInfo.Nonce + ":" + ha2)
-
- return authInfo.Response == correctResponse
-}
-
-// md5Hex 计算字符串的 MD5 哈希值并返回十六进制字符串
-func md5Hex(s string) string {
- hash := md5.New()
- hash.Write([]byte(s))
- return hex.EncodeToString(hash.Sum(nil))
-}
\ No newline at end of file
+package service
+
+import (
+ "crypto/md5"
+ "crypto/rand"
+ "encoding/hex"
+ "fmt"
+ "strings"
+)
+
+// AuthInfo 存储解析后的认证信息
+type AuthInfo struct {
+ Username string
+ Realm string
+ Nonce string
+ URI string
+ Response string
+ Algorithm string
+ Method string
+}
+
+// GenerateNonce 生成随机 nonce 字符串
+func GenerateNonce() string {
+ b := make([]byte, 16)
+ rand.Read(b)
+ return fmt.Sprintf("%x", b)
+}
+
+// ParseAuthorization 解析 SIP Authorization 头
+// Authorization: Digest username="34020000001320000001",realm="3402000000",
+// nonce="44010b73623249f6916a6acf7c316b8e",uri="sip:34020000002000000001@3402000000",
+// response="e4ca3fdc5869fa1c544ea7af60014444",algorithm=MD5
+func ParseAuthorization(auth string) *AuthInfo {
+ auth = strings.TrimPrefix(auth, "Digest ")
+ parts := strings.Split(auth, ",")
+ result := &AuthInfo{}
+
+ for _, part := range parts {
+ part = strings.TrimSpace(part)
+ if !strings.Contains(part, "=") {
+ continue
+ }
+
+ kv := strings.SplitN(part, "=", 2)
+ key := strings.TrimSpace(kv[0])
+ value := strings.Trim(strings.TrimSpace(kv[1]), "\"")
+
+ switch key {
+ case "username":
+ result.Username = value
+ case "realm":
+ result.Realm = value
+ case "nonce":
+ result.Nonce = value
+ case "uri":
+ result.URI = value
+ case "response":
+ result.Response = value
+ case "algorithm":
+ result.Algorithm = value
+ }
+ }
+
+ return result
+}
+
+// ValidateAuth 验证 SIP 认证信息
+func ValidateAuth(authInfo *AuthInfo, password string) bool {
+ if authInfo == nil {
+ return false
+ }
+
+ // 默认方法为 REGISTER
+ method := "REGISTER"
+ if authInfo.Method != "" {
+ method = authInfo.Method
+ }
+
+ // 计算 MD5 哈希
+ ha1 := md5Hex(authInfo.Username + ":" + authInfo.Realm + ":" + password)
+ ha2 := md5Hex(method + ":" + authInfo.URI)
+ correctResponse := md5Hex(ha1 + ":" + authInfo.Nonce + ":" + ha2)
+
+ return authInfo.Response == correctResponse
+}
+
+// md5Hex 计算字符串的 MD5 哈希值并返回十六进制字符串
+func md5Hex(s string) string {
+ hash := md5.New()
+ hash.Write([]byte(s))
+ return hex.EncodeToString(hash.Sum(nil))
+}
diff --git a/pkg/service/auth_test.go b/pkg/service/auth_test.go
index f0f0835..6a4f3b6 100644
--- a/pkg/service/auth_test.go
+++ b/pkg/service/auth_test.go
@@ -343,4 +343,3 @@ func TestParseAuthorizationQuotedValues(t *testing.T) {
t.Logf("Realm value: '%s'", result.Realm)
}
}
-
diff --git a/pkg/service/cascade.go b/pkg/service/cascade.go
index e7360c6..d214508 100644
--- a/pkg/service/cascade.go
+++ b/pkg/service/cascade.go
@@ -1,17 +1,17 @@
-package service
-
-import (
- "context"
-
- "github.com/emiago/sipgo"
- "github.com/ossrs/srs-sip/pkg/config"
-)
-
-type Cascade struct {
- ua *sipgo.UserAgent
- sipCli *sipgo.Client
- sipSvr *sipgo.Server
-
- ctx context.Context
- conf *config.MainConfig
-}
+package service
+
+import (
+ "context"
+
+ "github.com/emiago/sipgo"
+ "github.com/ossrs/srs-sip/pkg/config"
+)
+
+type Cascade struct {
+ ua *sipgo.UserAgent
+ sipCli *sipgo.Client
+ sipSvr *sipgo.Server
+
+ ctx context.Context
+ conf *config.MainConfig
+}
diff --git a/pkg/service/inbound.go b/pkg/service/inbound.go
index 3114a65..7737ac5 100644
--- a/pkg/service/inbound.go
+++ b/pkg/service/inbound.go
@@ -1,171 +1,171 @@
-package service
-
-import (
- "bytes"
- "encoding/xml"
- "fmt"
- "log/slog"
- "net"
- "net/http"
- "strconv"
-
- "github.com/emiago/sipgo/sip"
- "github.com/ossrs/srs-sip/pkg/models"
- "github.com/ossrs/srs-sip/pkg/service/stack"
- "golang.org/x/net/html/charset"
-)
-
-const GB28181_ID_LENGTH = 20
-
-func (s *UAS) isSameIP(addr1, addr2 string) bool {
- ip1, _, err1 := net.SplitHostPort(addr1)
- ip2, _, err2 := net.SplitHostPort(addr2)
-
- // 如果解析出错,回退到完整字符串比较
- if err1 != nil || err2 != nil {
- return addr1 == addr2
- }
-
- return ip1 == ip2
-}
-
-func (s *UAS) onRegister(req *sip.Request, tx sip.ServerTransaction) {
- id := req.From().Address.User
- if len(id) != GB28181_ID_LENGTH {
- slog.Error("invalid device ID")
- return
- }
-
- slog.Debug(fmt.Sprintf("Received REGISTER %s", req.String()))
-
- if s.conf.GB28181.Auth.Enable {
- // Check if Authorization header exists
- authHeader := req.GetHeaders("Authorization")
-
- // If no Authorization header, send 401 response to request authentication
- if len(authHeader) == 0 {
- nonce := GenerateNonce()
- resp := stack.NewUnauthorizedResponse(req, http.StatusUnauthorized, "Unauthorized", nonce, s.conf.GB28181.Realm)
- _ = tx.Respond(resp)
- return
- }
-
- // Validate Authorization
- authInfo := ParseAuthorization(authHeader[0].Value())
- if !ValidateAuth(authInfo, s.conf.GB28181.Auth.Password) {
- slog.Error("auth failed", "device_id", id, "source", req.Source())
- s.respondRegister(req, http.StatusForbidden, "Auth Failed", tx)
- return
- }
- }
-
- isUnregister := false
- if exps := req.GetHeaders("Expires"); len(exps) > 0 {
- exp := exps[0]
- expSec, err := strconv.ParseInt(exp.Value(), 10, 32)
- if err != nil {
- slog.Error("parse expires header error", "error", err.Error())
- return
- }
- if expSec == 0 {
- isUnregister = true
- }
- } else {
- slog.Error("empty expires header")
- return
- }
-
- if isUnregister {
- DM.RemoveDevice(id)
- slog.Warn("Device unregistered", "device_id", id)
- return
- } else {
- if d, ok := DM.GetDevice(id); !ok {
- DM.AddDevice(id, &DeviceInfo{
- DeviceID: id,
- SourceAddr: req.Source(),
- NetworkType: req.Transport(),
- })
- s.respondRegister(req, http.StatusOK, "OK", tx)
- slog.Info(fmt.Sprintf("Register success %s %s", id, req.Source()))
-
- go s.ConfigDownload(id)
- go s.Catalog(id)
- } else {
- if d.SourceAddr != "" && !s.isSameIP(d.SourceAddr, req.Source()) {
- slog.Error("Device already registered", "device_id", id, "old_source", d.SourceAddr, "new_source", req.Source())
- // TODO: 如果ID重复,应采用虚拟ID
- s.respondRegister(req, http.StatusBadRequest, "Conflict Device ID", tx)
- } else {
- d.SourceAddr = req.Source()
- d.NetworkType = req.Transport()
- DM.UpdateDevice(id, d)
- s.respondRegister(req, http.StatusOK, "OK", tx)
-
- slog.Info(fmt.Sprintf("Re-register success %s %s", id, req.Source()))
- }
- }
- }
-}
-
-func (s *UAS) respondRegister(req *sip.Request, code sip.StatusCode, reason string, tx sip.ServerTransaction) {
- res := stack.NewRegisterResponse(req, code, reason)
- _ = tx.Respond(res)
-
-}
-
-func (s *UAS) onMessage(req *sip.Request, tx sip.ServerTransaction) {
- id := req.From().Address.User
- if len(id) != 20 {
- slog.Error("invalid device ID", "request", req.String())
- }
-
- slog.Debug(fmt.Sprintf("Received MESSAGE %s", req.String()))
-
- temp := &models.XmlMessageInfo{}
- decoder := xml.NewDecoder(bytes.NewReader([]byte(req.Body())))
- decoder.CharsetReader = charset.NewReaderLabel
- if err := decoder.Decode(temp); err != nil {
- slog.Error("decode message error", "error", err.Error(), "message", req.Body())
- }
-
- slog.Info(fmt.Sprintf("Received MESSAGE %s %s %s", temp.CmdType, temp.DeviceID, req.Source()))
-
- var body string
- switch temp.CmdType {
- case "Keepalive":
- if d, ok := DM.GetDevice(temp.DeviceID); ok && d.Online {
- // 更新设备心跳时间
- DM.UpdateDeviceHeartbeat(temp.DeviceID)
- } else {
- tx.Respond(sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil))
- return
- }
- case "SensorCatalog": // 兼容宇视,非国标
- case "Catalog":
- DM.UpdateChannels(temp.DeviceID, temp.DeviceList...)
- //go s.AutoInvite(temp.DeviceID, temp.DeviceList...)
- case "ConfigDownload":
- DM.UpdateDeviceConfig(temp.DeviceID, &temp.BasicParam)
- case "Alarm":
- slog.Info("Alarm")
- case "RecordInfo":
- // 从 recordQueryResults 中获取对应通道的结果通道
- if ch, ok := s.recordQueryResults.Load(temp.DeviceID); ok {
- // 发送查询结果
- resultChan := ch.(chan *models.XmlMessageInfo)
- resultChan <- temp
- }
- default:
- slog.Warn("Not supported CmdType", "cmd_type", temp.CmdType)
- response := sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil)
- tx.Respond(response)
- return
- }
- tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", []byte(body)))
-}
-
-func (s *UAS) onNotify(req *sip.Request, tx sip.ServerTransaction) {
- slog.Debug(fmt.Sprintf("Received NOTIFY %s", req.String()))
- tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", nil))
-}
+package service
+
+import (
+ "bytes"
+ "encoding/xml"
+ "fmt"
+ "log/slog"
+ "net"
+ "net/http"
+ "strconv"
+
+ "github.com/emiago/sipgo/sip"
+ "github.com/ossrs/srs-sip/pkg/models"
+ "github.com/ossrs/srs-sip/pkg/service/stack"
+ "golang.org/x/net/html/charset"
+)
+
+const GB28181_ID_LENGTH = 20
+
+func (s *UAS) isSameIP(addr1, addr2 string) bool {
+ ip1, _, err1 := net.SplitHostPort(addr1)
+ ip2, _, err2 := net.SplitHostPort(addr2)
+
+ // 如果解析出错,回退到完整字符串比较
+ if err1 != nil || err2 != nil {
+ return addr1 == addr2
+ }
+
+ return ip1 == ip2
+}
+
+func (s *UAS) onRegister(req *sip.Request, tx sip.ServerTransaction) {
+ id := req.From().Address.User
+ if len(id) != GB28181_ID_LENGTH {
+ slog.Error("invalid device ID")
+ return
+ }
+
+ slog.Debug(fmt.Sprintf("Received REGISTER %s", req.String()))
+
+ if s.conf.GB28181.Auth.Enable {
+ // Check if Authorization header exists
+ authHeader := req.GetHeaders("Authorization")
+
+ // If no Authorization header, send 401 response to request authentication
+ if len(authHeader) == 0 {
+ nonce := GenerateNonce()
+ resp := stack.NewUnauthorizedResponse(req, http.StatusUnauthorized, "Unauthorized", nonce, s.conf.GB28181.Realm)
+ _ = tx.Respond(resp)
+ return
+ }
+
+ // Validate Authorization
+ authInfo := ParseAuthorization(authHeader[0].Value())
+ if !ValidateAuth(authInfo, s.conf.GB28181.Auth.Password) {
+ slog.Error("auth failed", "device_id", id, "source", req.Source())
+ s.respondRegister(req, http.StatusForbidden, "Auth Failed", tx)
+ return
+ }
+ }
+
+ isUnregister := false
+ if exps := req.GetHeaders("Expires"); len(exps) > 0 {
+ exp := exps[0]
+ expSec, err := strconv.ParseInt(exp.Value(), 10, 32)
+ if err != nil {
+ slog.Error("parse expires header error", "error", err.Error())
+ return
+ }
+ if expSec == 0 {
+ isUnregister = true
+ }
+ } else {
+ slog.Error("empty expires header")
+ return
+ }
+
+ if isUnregister {
+ DM.RemoveDevice(id)
+ slog.Warn("Device unregistered", "device_id", id)
+ return
+ } else {
+ if d, ok := DM.GetDevice(id); !ok {
+ DM.AddDevice(id, &DeviceInfo{
+ DeviceID: id,
+ SourceAddr: req.Source(),
+ NetworkType: req.Transport(),
+ })
+ s.respondRegister(req, http.StatusOK, "OK", tx)
+ slog.Info(fmt.Sprintf("Register success %s %s", id, req.Source()))
+
+ go s.ConfigDownload(id)
+ go s.Catalog(id)
+ } else {
+ if d.SourceAddr != "" && !s.isSameIP(d.SourceAddr, req.Source()) {
+ slog.Error("Device already registered", "device_id", id, "old_source", d.SourceAddr, "new_source", req.Source())
+ // TODO: 如果ID重复,应采用虚拟ID
+ s.respondRegister(req, http.StatusBadRequest, "Conflict Device ID", tx)
+ } else {
+ d.SourceAddr = req.Source()
+ d.NetworkType = req.Transport()
+ DM.UpdateDevice(id, d)
+ s.respondRegister(req, http.StatusOK, "OK", tx)
+
+ slog.Info(fmt.Sprintf("Re-register success %s %s", id, req.Source()))
+ }
+ }
+ }
+}
+
+func (s *UAS) respondRegister(req *sip.Request, code sip.StatusCode, reason string, tx sip.ServerTransaction) {
+ res := stack.NewRegisterResponse(req, code, reason)
+ _ = tx.Respond(res)
+
+}
+
+func (s *UAS) onMessage(req *sip.Request, tx sip.ServerTransaction) {
+ id := req.From().Address.User
+ if len(id) != 20 {
+ slog.Error("invalid device ID", "request", req.String())
+ }
+
+ slog.Debug(fmt.Sprintf("Received MESSAGE %s", req.String()))
+
+ temp := &models.XmlMessageInfo{}
+ decoder := xml.NewDecoder(bytes.NewReader([]byte(req.Body())))
+ decoder.CharsetReader = charset.NewReaderLabel
+ if err := decoder.Decode(temp); err != nil {
+ slog.Error("decode message error", "error", err.Error(), "message", req.Body())
+ }
+
+ slog.Info(fmt.Sprintf("Received MESSAGE %s %s %s", temp.CmdType, temp.DeviceID, req.Source()))
+
+ var body string
+ switch temp.CmdType {
+ case "Keepalive":
+ if d, ok := DM.GetDevice(temp.DeviceID); ok && d.Online {
+ // 更新设备心跳时间
+ DM.UpdateDeviceHeartbeat(temp.DeviceID)
+ } else {
+ tx.Respond(sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil))
+ return
+ }
+ case "SensorCatalog": // 兼容宇视,非国标
+ case "Catalog":
+ DM.UpdateChannels(temp.DeviceID, temp.DeviceList...)
+ //go s.AutoInvite(temp.DeviceID, temp.DeviceList...)
+ case "ConfigDownload":
+ DM.UpdateDeviceConfig(temp.DeviceID, &temp.BasicParam)
+ case "Alarm":
+ slog.Info("Alarm")
+ case "RecordInfo":
+ // 从 recordQueryResults 中获取对应通道的结果通道
+ if ch, ok := s.recordQueryResults.Load(temp.DeviceID); ok {
+ // 发送查询结果
+ resultChan := ch.(chan *models.XmlMessageInfo)
+ resultChan <- temp
+ }
+ default:
+ slog.Warn("Not supported CmdType", "cmd_type", temp.CmdType)
+ response := sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil)
+ tx.Respond(response)
+ return
+ }
+ tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", []byte(body)))
+}
+
+func (s *UAS) onNotify(req *sip.Request, tx sip.ServerTransaction) {
+ slog.Debug(fmt.Sprintf("Received NOTIFY %s", req.String()))
+ tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", nil))
+}
diff --git a/pkg/service/ptz.go b/pkg/service/ptz.go
index e561643..d62a4fa 100644
--- a/pkg/service/ptz.go
+++ b/pkg/service/ptz.go
@@ -1,81 +1,81 @@
-package service
-
-import "fmt"
-
-var (
- ptzCmdMap = map[string]uint8{
- "stop": 0,
- "right": 1,
- "left": 2,
- "down": 4,
- "downright": 5,
- "downleft": 6,
- "up": 8,
- "upright": 9,
- "upleft": 10,
- "zoomin": 16,
- "zoomout": 32,
- }
-
- ptzSpeedMap = map[string]uint8{
- "1": 25,
- "2": 50,
- "3": 75,
- "4": 100,
- "5": 125,
- "6": 150,
- "7": 175,
- "8": 200,
- "9": 225,
- "10": 255,
- }
-
- defaultSpeed uint8 = 125
-)
-
-func getPTZSpeed(speed string) uint8 {
- if v, ok := ptzSpeedMap[speed]; ok {
- return v
- }
- return defaultSpeed
-}
-
-func toPTZCmd(cmdName, speed string) (string, error) {
- cmdCode, ok := ptzCmdMap[cmdName]
- if !ok {
- return "", fmt.Errorf("invalid ptz command: %q", cmdName)
- }
-
- speedValue := getPTZSpeed(speed)
-
- var horizontalSpeed, verticalSpeed, zSpeed uint8
-
- switch cmdName {
- case "left", "right":
- horizontalSpeed = speedValue
- verticalSpeed = 0
- case "up", "down":
- verticalSpeed = speedValue
- horizontalSpeed = 0
- case "upleft", "upright", "downleft", "downright":
- verticalSpeed = speedValue
- horizontalSpeed = speedValue
- case "zoomin", "zoomout":
- zSpeed = speedValue << 4 // zoom速度在高4位
- default:
- horizontalSpeed = 0
- verticalSpeed = 0
- zSpeed = 0
- }
-
- sum := uint16(0xA5) + uint16(0x0F) + uint16(0x01) + uint16(cmdCode) + uint16(horizontalSpeed) + uint16(verticalSpeed) + uint16(zSpeed)
- checksum := uint8(sum % 256)
-
- return fmt.Sprintf("A50F01%02X%02X%02X%02X%02X",
- cmdCode,
- horizontalSpeed,
- verticalSpeed,
- zSpeed,
- checksum,
- ), nil
-}
+package service
+
+import "fmt"
+
+var (
+ ptzCmdMap = map[string]uint8{
+ "stop": 0,
+ "right": 1,
+ "left": 2,
+ "down": 4,
+ "downright": 5,
+ "downleft": 6,
+ "up": 8,
+ "upright": 9,
+ "upleft": 10,
+ "zoomin": 16,
+ "zoomout": 32,
+ }
+
+ ptzSpeedMap = map[string]uint8{
+ "1": 25,
+ "2": 50,
+ "3": 75,
+ "4": 100,
+ "5": 125,
+ "6": 150,
+ "7": 175,
+ "8": 200,
+ "9": 225,
+ "10": 255,
+ }
+
+ defaultSpeed uint8 = 125
+)
+
+func getPTZSpeed(speed string) uint8 {
+ if v, ok := ptzSpeedMap[speed]; ok {
+ return v
+ }
+ return defaultSpeed
+}
+
+func toPTZCmd(cmdName, speed string) (string, error) {
+ cmdCode, ok := ptzCmdMap[cmdName]
+ if !ok {
+ return "", fmt.Errorf("invalid ptz command: %q", cmdName)
+ }
+
+ speedValue := getPTZSpeed(speed)
+
+ var horizontalSpeed, verticalSpeed, zSpeed uint8
+
+ switch cmdName {
+ case "left", "right":
+ horizontalSpeed = speedValue
+ verticalSpeed = 0
+ case "up", "down":
+ verticalSpeed = speedValue
+ horizontalSpeed = 0
+ case "upleft", "upright", "downleft", "downright":
+ verticalSpeed = speedValue
+ horizontalSpeed = speedValue
+ case "zoomin", "zoomout":
+ zSpeed = speedValue << 4 // zoom速度在高4位
+ default:
+ horizontalSpeed = 0
+ verticalSpeed = 0
+ zSpeed = 0
+ }
+
+ sum := uint16(0xA5) + uint16(0x0F) + uint16(0x01) + uint16(cmdCode) + uint16(horizontalSpeed) + uint16(verticalSpeed) + uint16(zSpeed)
+ checksum := uint8(sum % 256)
+
+ return fmt.Sprintf("A50F01%02X%02X%02X%02X%02X",
+ cmdCode,
+ horizontalSpeed,
+ verticalSpeed,
+ zSpeed,
+ checksum,
+ ), nil
+}
diff --git a/pkg/service/ptz_test.go b/pkg/service/ptz_test.go
index 14e5e88..a715a0f 100644
--- a/pkg/service/ptz_test.go
+++ b/pkg/service/ptz_test.go
@@ -141,7 +141,7 @@ func TestToPTZCmdSpecificCases(t *testing.T) {
func TestToPTZCmdWithDifferentSpeeds(t *testing.T) {
speeds := []string{"1", "5", "10"}
-
+
for _, speed := range speeds {
t.Run("Right with speed "+speed, func(t *testing.T) {
result, err := toPTZCmd("right", speed)
@@ -196,4 +196,3 @@ func TestPTZSpeedMap(t *testing.T) {
})
}
}
-
diff --git a/pkg/service/stack/request.go b/pkg/service/stack/request.go
index ebb7e54..c9567ae 100644
--- a/pkg/service/stack/request.go
+++ b/pkg/service/stack/request.go
@@ -1,68 +1,68 @@
-package stack
-
-import (
- "github.com/emiago/sipgo/sip"
- "github.com/ossrs/go-oryx-lib/errors"
-)
-
-type OutboundConfig struct {
- Transport string
- Via string
- From string
- To string
-}
-
-func NewRequest(method sip.RequestMethod, body []byte, conf OutboundConfig) (*sip.Request, error) {
- if len(conf.From) != 20 || len(conf.To) != 20 {
- return nil, errors.Errorf("From or To length is not 20")
- }
-
- dest := conf.Via
- to := sip.Uri{User: conf.To, Host: conf.To[:10]}
- from := &sip.Uri{User: conf.From, Host: conf.From[:10]}
-
- fromHeader := &sip.FromHeader{Address: *from, Params: sip.NewParams()}
- fromHeader.Params.Add("tag", sip.GenerateTagN(16))
-
- req := sip.NewRequest(method, to)
- req.AppendHeader(fromHeader)
- req.AppendHeader(&sip.ToHeader{Address: to})
- req.AppendHeader(&sip.ContactHeader{Address: *from})
- req.AppendHeader(sip.NewHeader("Max-Forwards", "70"))
- req.SetBody(body)
- req.SetDestination(dest)
- req.SetTransport(conf.Transport)
-
- return req, nil
-}
-
-func NewRegisterRequest(conf OutboundConfig) (*sip.Request, error) {
- req, err := NewRequest(sip.REGISTER, nil, conf)
- if err != nil {
- return nil, err
- }
- req.AppendHeader(sip.NewHeader("Expires", "3600"))
-
- return req, nil
-}
-
-func NewInviteRequest(body []byte, subject string, conf OutboundConfig) (*sip.Request, error) {
- req, err := NewRequest(sip.INVITE, body, conf)
- if err != nil {
- return nil, err
- }
- req.AppendHeader(sip.NewHeader("Content-Type", "application/sdp"))
- req.AppendHeader(sip.NewHeader("Subject", subject))
-
- return req, nil
-}
-
-func NewMessageRequest(body []byte, conf OutboundConfig) (*sip.Request, error) {
- req, err := NewRequest(sip.MESSAGE, body, conf)
- if err != nil {
- return nil, err
- }
- req.AppendHeader(sip.NewHeader("Content-Type", "Application/MANSCDP+xml"))
-
- return req, nil
-}
+package stack
+
+import (
+ "github.com/emiago/sipgo/sip"
+ "github.com/ossrs/go-oryx-lib/errors"
+)
+
+type OutboundConfig struct {
+ Transport string
+ Via string
+ From string
+ To string
+}
+
+func NewRequest(method sip.RequestMethod, body []byte, conf OutboundConfig) (*sip.Request, error) {
+ if len(conf.From) != 20 || len(conf.To) != 20 {
+ return nil, errors.Errorf("From or To length is not 20")
+ }
+
+ dest := conf.Via
+ to := sip.Uri{User: conf.To, Host: conf.To[:10]}
+ from := &sip.Uri{User: conf.From, Host: conf.From[:10]}
+
+ fromHeader := &sip.FromHeader{Address: *from, Params: sip.NewParams()}
+ fromHeader.Params.Add("tag", sip.GenerateTagN(16))
+
+ req := sip.NewRequest(method, to)
+ req.AppendHeader(fromHeader)
+ req.AppendHeader(&sip.ToHeader{Address: to})
+ req.AppendHeader(&sip.ContactHeader{Address: *from})
+ req.AppendHeader(sip.NewHeader("Max-Forwards", "70"))
+ req.SetBody(body)
+ req.SetDestination(dest)
+ req.SetTransport(conf.Transport)
+
+ return req, nil
+}
+
+func NewRegisterRequest(conf OutboundConfig) (*sip.Request, error) {
+ req, err := NewRequest(sip.REGISTER, nil, conf)
+ if err != nil {
+ return nil, err
+ }
+ req.AppendHeader(sip.NewHeader("Expires", "3600"))
+
+ return req, nil
+}
+
+func NewInviteRequest(body []byte, subject string, conf OutboundConfig) (*sip.Request, error) {
+ req, err := NewRequest(sip.INVITE, body, conf)
+ if err != nil {
+ return nil, err
+ }
+ req.AppendHeader(sip.NewHeader("Content-Type", "application/sdp"))
+ req.AppendHeader(sip.NewHeader("Subject", subject))
+
+ return req, nil
+}
+
+func NewMessageRequest(body []byte, conf OutboundConfig) (*sip.Request, error) {
+ req, err := NewRequest(sip.MESSAGE, body, conf)
+ if err != nil {
+ return nil, err
+ }
+ req.AppendHeader(sip.NewHeader("Content-Type", "Application/MANSCDP+xml"))
+
+ return req, nil
+}
diff --git a/pkg/service/stack/response.go b/pkg/service/stack/response.go
index c3d4149..3009e09 100644
--- a/pkg/service/stack/response.go
+++ b/pkg/service/stack/response.go
@@ -1,41 +1,41 @@
-package stack
-
-import (
- "fmt"
- "time"
-
- "github.com/emiago/sipgo/sip"
-)
-
-const TIME_LAYOUT = "2024-01-01T00:00:00"
-const EXPIRES_TIME = 3600
-
-func newResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response {
- resp := sip.NewResponseFromRequest(req, code, reason, nil)
-
- newTo := &sip.ToHeader{Address: resp.To().Address, Params: sip.NewParams()}
- newTo.Params.Add("tag", sip.GenerateTagN(10))
-
- resp.ReplaceHeader(newTo)
- resp.RemoveHeader("Allow")
-
- return resp
-}
-
-func NewRegisterResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response {
- resp := newResponse(req, code, reason)
-
- expires := sip.ExpiresHeader(EXPIRES_TIME)
- resp.AppendHeader(&expires)
- resp.AppendHeader(sip.NewHeader("Date", time.Now().Format(TIME_LAYOUT)))
-
- return resp
-}
-
-func NewUnauthorizedResponse(req *sip.Request, code sip.StatusCode, reason, nonce, realm string) *sip.Response {
- resp := newResponse(req, code, reason)
-
- resp.AppendHeader(sip.NewHeader("WWW-Authenticate", fmt.Sprintf(`Digest realm="%s",nonce="%s",algorithm=MD5`, realm, nonce)))
-
- return resp
-}
+package stack
+
+import (
+ "fmt"
+ "time"
+
+ "github.com/emiago/sipgo/sip"
+)
+
+const TIME_LAYOUT = "2024-01-01T00:00:00"
+const EXPIRES_TIME = 3600
+
+func newResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response {
+ resp := sip.NewResponseFromRequest(req, code, reason, nil)
+
+ newTo := &sip.ToHeader{Address: resp.To().Address, Params: sip.NewParams()}
+ newTo.Params.Add("tag", sip.GenerateTagN(10))
+
+ resp.ReplaceHeader(newTo)
+ resp.RemoveHeader("Allow")
+
+ return resp
+}
+
+func NewRegisterResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response {
+ resp := newResponse(req, code, reason)
+
+ expires := sip.ExpiresHeader(EXPIRES_TIME)
+ resp.AppendHeader(&expires)
+ resp.AppendHeader(sip.NewHeader("Date", time.Now().Format(TIME_LAYOUT)))
+
+ return resp
+}
+
+func NewUnauthorizedResponse(req *sip.Request, code sip.StatusCode, reason, nonce, realm string) *sip.Response {
+ resp := newResponse(req, code, reason)
+
+ resp.AppendHeader(sip.NewHeader("WWW-Authenticate", fmt.Sprintf(`Digest realm="%s",nonce="%s",algorithm=MD5`, realm, nonce)))
+
+ return resp
+}
diff --git a/pkg/service/uac.go b/pkg/service/uac.go
index 066f835..dfd773c 100644
--- a/pkg/service/uac.go
+++ b/pkg/service/uac.go
@@ -1,122 +1,122 @@
-package service
-
-import (
- "context"
- "fmt"
- "log/slog"
-
- "github.com/emiago/sipgo"
- "github.com/emiago/sipgo/sip"
- "github.com/ossrs/go-oryx-lib/errors"
- "github.com/ossrs/srs-sip/pkg/config"
- "github.com/ossrs/srs-sip/pkg/service/stack"
-)
-
-const (
- UserAgent = "SRS-SIP/1.0"
-)
-
-type UAC struct {
- *Cascade
-
- SN uint32
- LocalIP string
-}
-
-func NewUac() *UAC {
- ip, err := config.GetLocalIP()
- if err != nil {
- slog.Error("get local ip failed", "error", err)
- return nil
- }
-
- c := &UAC{
- Cascade: &Cascade{},
- LocalIP: ip,
- }
- return c
-}
-
-func (c *UAC) Start(agent *sipgo.UserAgent, r0 interface{}) error {
- var err error
-
- c.ctx = context.Background()
- c.conf = r0.(*config.MainConfig)
-
- if agent == nil {
- ua, err := sipgo.NewUA(sipgo.WithUserAgent(UserAgent))
- if err != nil {
- return err
- }
- agent = ua
- }
-
- c.sipCli, err = sipgo.NewClient(agent, sipgo.WithClientHostname(c.LocalIP))
- if err != nil {
- return err
- }
-
- c.sipSvr, err = sipgo.NewServer(agent)
- if err != nil {
- return err
- }
-
- c.sipSvr.OnInvite(c.onInvite)
- c.sipSvr.OnBye(c.onBye)
- c.sipSvr.OnMessage(c.onMessage)
-
- go c.doRegister()
-
- return nil
-}
-
-func (c *UAC) Stop() {
- // TODO: 断开所有当前连接
- c.sipCli.Close()
- c.sipSvr.Close()
-}
-
-func (c *UAC) doRegister() error {
- r, _ := stack.NewRegisterRequest(stack.OutboundConfig{
- From: "34020000001110000001",
- To: "34020000002000000001",
- Transport: "UDP",
- Via: fmt.Sprintf("%s:%d", c.LocalIP, c.conf.GB28181.Port),
- })
- tx, err := c.sipCli.TransactionRequest(c.ctx, r)
- if err != nil {
- return errors.Wrapf(err, "transaction request error")
- }
-
- rs, _ := c.getResponse(tx)
- slog.Info("register response", "response", rs.String())
- return nil
-}
-
-func (c *UAC) OnRequest(req *sip.Request, tx sip.ServerTransaction) {
- switch req.Method {
- case "INVITE":
- c.onInvite(req, tx)
- }
-}
-
-func (c *UAC) onInvite(req *sip.Request, tx sip.ServerTransaction) {
- slog.Debug("onInvite")
-}
-
-func (c *UAC) onBye(req *sip.Request, tx sip.ServerTransaction) {
- slog.Debug("onBye")
-}
-
-func (c *UAC) onMessage(req *sip.Request, tx sip.ServerTransaction) {
- slog.Debug("onMessage", "request", req.String())
-}
-
-func (c *UAC) getResponse(tx sip.ClientTransaction) (*sip.Response, error) {
- select {
- case <-tx.Done():
- return nil, fmt.Errorf("transaction died")
- case res := <-tx.Responses():
- return res, nil
- }
-}
+package service
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+
+ "github.com/emiago/sipgo"
+ "github.com/emiago/sipgo/sip"
+ "github.com/ossrs/go-oryx-lib/errors"
+ "github.com/ossrs/srs-sip/pkg/config"
+ "github.com/ossrs/srs-sip/pkg/service/stack"
+)
+
+const (
+ UserAgent = "SRS-SIP/1.0"
+)
+
+type UAC struct {
+ *Cascade
+
+ SN uint32
+ LocalIP string
+}
+
+func NewUac() *UAC {
+ ip, err := config.GetLocalIP()
+ if err != nil {
+ slog.Error("get local ip failed", "error", err)
+ return nil
+ }
+
+ c := &UAC{
+ Cascade: &Cascade{},
+ LocalIP: ip,
+ }
+ return c
+}
+
+func (c *UAC) Start(agent *sipgo.UserAgent, r0 interface{}) error {
+ var err error
+
+ c.ctx = context.Background()
+ c.conf = r0.(*config.MainConfig)
+
+ if agent == nil {
+ ua, err := sipgo.NewUA(sipgo.WithUserAgent(UserAgent))
+ if err != nil {
+ return err
+ }
+ agent = ua
+ }
+
+ c.sipCli, err = sipgo.NewClient(agent, sipgo.WithClientHostname(c.LocalIP))
+ if err != nil {
+ return err
+ }
+
+ c.sipSvr, err = sipgo.NewServer(agent)
+ if err != nil {
+ return err
+ }
+
+ c.sipSvr.OnInvite(c.onInvite)
+ c.sipSvr.OnBye(c.onBye)
+ c.sipSvr.OnMessage(c.onMessage)
+
+ go c.doRegister()
+
+ return nil
+}
+
+func (c *UAC) Stop() {
+ // TODO: 断开所有当前连接
+ c.sipCli.Close()
+ c.sipSvr.Close()
+}
+
+func (c *UAC) doRegister() error {
+ r, _ := stack.NewRegisterRequest(stack.OutboundConfig{
+ From: "34020000001110000001",
+ To: "34020000002000000001",
+ Transport: "UDP",
+ Via: fmt.Sprintf("%s:%d", c.LocalIP, c.conf.GB28181.Port),
+ })
+ tx, err := c.sipCli.TransactionRequest(c.ctx, r)
+ if err != nil {
+ return errors.Wrapf(err, "transaction request error")
+ }
+
+ rs, _ := c.getResponse(tx)
+ slog.Info("register response", "response", rs.String())
+ return nil
+}
+
+func (c *UAC) OnRequest(req *sip.Request, tx sip.ServerTransaction) {
+ switch req.Method {
+ case "INVITE":
+ c.onInvite(req, tx)
+ }
+}
+
+func (c *UAC) onInvite(req *sip.Request, tx sip.ServerTransaction) {
+ slog.Debug("onInvite")
+}
+
+func (c *UAC) onBye(req *sip.Request, tx sip.ServerTransaction) {
+ slog.Debug("onBye")
+}
+
+func (c *UAC) onMessage(req *sip.Request, tx sip.ServerTransaction) {
+ slog.Debug("onMessage", "request", req.String())
+}
+
+func (c *UAC) getResponse(tx sip.ClientTransaction) (*sip.Response, error) {
+ select {
+ case <-tx.Done():
+ return nil, fmt.Errorf("transaction died")
+ case res := <-tx.Responses():
+ return res, nil
+ }
+}
diff --git a/pkg/utils/logger.go b/pkg/utils/logger.go
index 3f9e5b1..f01df8d 100644
--- a/pkg/utils/logger.go
+++ b/pkg/utils/logger.go
@@ -1,201 +1,201 @@
-package utils
-
-import (
- "context"
- "fmt"
- "io"
- "log/slog"
- "os"
- "path/filepath"
- "strings"
- "sync"
-)
-
-var logLevelMap = map[string]slog.Level{
- "debug": slog.LevelDebug,
- "info": slog.LevelInfo,
- "warn": slog.LevelWarn,
- "error": slog.LevelError,
-}
-
-// 自定义格式处理器,以 [时间] [级别] [消息] 格式输出日志
-type CustomFormatHandler struct {
- mu sync.Mutex
- w io.Writer
- level slog.Level
- attrs []slog.Attr
- groups []string
-}
-
-// NewCustomFormatHandler 创建一个新的自定义格式处理器
-func NewCustomFormatHandler(w io.Writer, opts *slog.HandlerOptions) *CustomFormatHandler {
- if opts == nil {
- opts = &slog.HandlerOptions{}
- }
-
- // 获取日志级别,如果opts.Level是nil则默认为Info
- var level slog.Level
- if opts.Level != nil {
- level = opts.Level.Level()
- } else {
- level = slog.LevelInfo
- }
-
- return &CustomFormatHandler{
- w: w,
- level: level,
- }
-}
-
-// Enabled 实现 slog.Handler 接口
-func (h *CustomFormatHandler) Enabled(ctx context.Context, level slog.Level) bool {
- return level >= h.level
-}
-
-// Handle 实现 slog.Handler 接口,以自定义格式输出日志
-func (h *CustomFormatHandler) Handle(ctx context.Context, record slog.Record) error {
- h.mu.Lock()
- defer h.mu.Unlock()
-
- // 时间格式
- timeStr := record.Time.Format("2006-01-02 15:04:05.000")
-
- // 日志级别
- var levelStr string
- switch {
- case record.Level >= slog.LevelError:
- levelStr = "ERROR"
- case record.Level >= slog.LevelWarn:
- levelStr = "WARN "
- case record.Level >= slog.LevelInfo:
- levelStr = "INFO "
- default:
- levelStr = "DEBUG"
- }
-
- // 构建日志行
- logLine := fmt.Sprintf("[%s] [%s] %s", timeStr, levelStr, record.Message)
-
- // 处理其他属性
- var attrs []string
- record.Attrs(func(attr slog.Attr) bool {
- attrs = append(attrs, fmt.Sprintf("%s=%v", attr.Key, attr.Value))
- return true
- })
-
- if len(attrs) > 0 {
- logLine += " " + strings.Join(attrs, " ")
- }
-
- // 写入日志
- _, err := fmt.Fprintln(h.w, logLine)
- return err
-}
-
-// WithAttrs 实现 slog.Handler 接口
-func (h *CustomFormatHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
- h2 := *h
- h2.attrs = append(h.attrs[:], attrs...)
- return &h2
-}
-
-// WithGroup 实现 slog.Handler 接口
-func (h *CustomFormatHandler) WithGroup(name string) slog.Handler {
- h2 := *h
- h2.groups = append(h.groups[:], name)
- return &h2
-}
-
-// MultiHandler 实现了 slog.Handler 接口,将日志同时发送到多个处理器
-type MultiHandler struct {
- handlers []slog.Handler
-}
-
-// Enabled 实现 slog.Handler 接口
-func (h *MultiHandler) Enabled(ctx context.Context, level slog.Level) bool {
- // 如果任何一个处理器启用了该级别,则返回 true
- for _, handler := range h.handlers {
- if handler.Enabled(ctx, level) {
- return true
- }
- }
- return false
-}
-
-// Handle 实现 slog.Handler 接口
-func (h *MultiHandler) Handle(ctx context.Context, record slog.Record) error {
- // 将记录发送到所有处理器
- for _, handler := range h.handlers {
- if handler.Enabled(ctx, record.Level) {
- if err := handler.Handle(ctx, record); err != nil {
- return err
- }
- }
- }
- return nil
-}
-
-// WithAttrs 实现 slog.Handler 接口
-func (h *MultiHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
- newHandlers := make([]slog.Handler, len(h.handlers))
- for i, handler := range h.handlers {
- newHandlers[i] = handler.WithAttrs(attrs)
- }
- return &MultiHandler{handlers: newHandlers}
-}
-
-// WithGroup 实现 slog.Handler 接口
-func (h *MultiHandler) WithGroup(name string) slog.Handler {
- newHandlers := make([]slog.Handler, len(h.handlers))
- for i, handler := range h.handlers {
- newHandlers[i] = handler.WithGroup(name)
- }
- return &MultiHandler{handlers: newHandlers}
-}
-
-// SetupLogger 设置日志输出
-func SetupLogger(logLevel string, logFile string) error {
- // 创建标准错误输出的处理器,使用自定义格式
- stdHandler := NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{
- Level: logLevelMap[logLevel],
- })
-
- // 如果没有指定日志文件,则仅使用标准错误处理器
- if logFile == "" {
- slog.SetDefault(slog.New(stdHandler))
- return nil
- }
-
- // 确保日志文件所在目录存在
- logDir := filepath.Dir(logFile)
- if err := os.MkdirAll(logDir, 0755); err != nil {
- return err
- }
-
- // 打开日志文件,如果不存在则创建,追加写入模式
- file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
- if err != nil {
- return err
- }
-
- // 创建文件输出的处理器,使用自定义格式
- fileHandler := NewCustomFormatHandler(file, &slog.HandlerOptions{
- Level: logLevelMap[logLevel],
- })
-
- // 创建多输出处理器
- multiHandler := &MultiHandler{
- handlers: []slog.Handler{stdHandler, fileHandler},
- }
-
- // 设置全局日志处理器
- slog.SetDefault(slog.New(multiHandler))
- return nil
-}
-
-// InitDefaultLogger 初始化默认日志处理器
-func InitDefaultLogger(level slog.Level) {
- slog.SetDefault(slog.New(NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{
- Level: level,
- })))
-}
+package utils
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "log/slog"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+)
+
+var logLevelMap = map[string]slog.Level{
+ "debug": slog.LevelDebug,
+ "info": slog.LevelInfo,
+ "warn": slog.LevelWarn,
+ "error": slog.LevelError,
+}
+
+// 自定义格式处理器,以 [时间] [级别] [消息] 格式输出日志
+type CustomFormatHandler struct {
+ mu sync.Mutex
+ w io.Writer
+ level slog.Level
+ attrs []slog.Attr
+ groups []string
+}
+
+// NewCustomFormatHandler 创建一个新的自定义格式处理器
+func NewCustomFormatHandler(w io.Writer, opts *slog.HandlerOptions) *CustomFormatHandler {
+ if opts == nil {
+ opts = &slog.HandlerOptions{}
+ }
+
+ // 获取日志级别,如果opts.Level是nil则默认为Info
+ var level slog.Level
+ if opts.Level != nil {
+ level = opts.Level.Level()
+ } else {
+ level = slog.LevelInfo
+ }
+
+ return &CustomFormatHandler{
+ w: w,
+ level: level,
+ }
+}
+
+// Enabled 实现 slog.Handler 接口
+func (h *CustomFormatHandler) Enabled(ctx context.Context, level slog.Level) bool {
+ return level >= h.level
+}
+
+// Handle 实现 slog.Handler 接口,以自定义格式输出日志
+func (h *CustomFormatHandler) Handle(ctx context.Context, record slog.Record) error {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+
+ // 时间格式
+ timeStr := record.Time.Format("2006-01-02 15:04:05.000")
+
+ // 日志级别
+ var levelStr string
+ switch {
+ case record.Level >= slog.LevelError:
+ levelStr = "ERROR"
+ case record.Level >= slog.LevelWarn:
+ levelStr = "WARN "
+ case record.Level >= slog.LevelInfo:
+ levelStr = "INFO "
+ default:
+ levelStr = "DEBUG"
+ }
+
+ // 构建日志行
+ logLine := fmt.Sprintf("[%s] [%s] %s", timeStr, levelStr, record.Message)
+
+ // 处理其他属性
+ var attrs []string
+ record.Attrs(func(attr slog.Attr) bool {
+ attrs = append(attrs, fmt.Sprintf("%s=%v", attr.Key, attr.Value))
+ return true
+ })
+
+ if len(attrs) > 0 {
+ logLine += " " + strings.Join(attrs, " ")
+ }
+
+ // 写入日志
+ _, err := fmt.Fprintln(h.w, logLine)
+ return err
+}
+
+// WithAttrs 实现 slog.Handler 接口
+func (h *CustomFormatHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
+ h2 := *h
+ h2.attrs = append(h.attrs[:], attrs...)
+ return &h2
+}
+
+// WithGroup 实现 slog.Handler 接口
+func (h *CustomFormatHandler) WithGroup(name string) slog.Handler {
+ h2 := *h
+ h2.groups = append(h.groups[:], name)
+ return &h2
+}
+
+// MultiHandler 实现了 slog.Handler 接口,将日志同时发送到多个处理器
+type MultiHandler struct {
+ handlers []slog.Handler
+}
+
+// Enabled 实现 slog.Handler 接口
+func (h *MultiHandler) Enabled(ctx context.Context, level slog.Level) bool {
+ // 如果任何一个处理器启用了该级别,则返回 true
+ for _, handler := range h.handlers {
+ if handler.Enabled(ctx, level) {
+ return true
+ }
+ }
+ return false
+}
+
+// Handle 实现 slog.Handler 接口
+func (h *MultiHandler) Handle(ctx context.Context, record slog.Record) error {
+ // 将记录发送到所有处理器
+ for _, handler := range h.handlers {
+ if handler.Enabled(ctx, record.Level) {
+ if err := handler.Handle(ctx, record); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+// WithAttrs 实现 slog.Handler 接口
+func (h *MultiHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
+ newHandlers := make([]slog.Handler, len(h.handlers))
+ for i, handler := range h.handlers {
+ newHandlers[i] = handler.WithAttrs(attrs)
+ }
+ return &MultiHandler{handlers: newHandlers}
+}
+
+// WithGroup 实现 slog.Handler 接口
+func (h *MultiHandler) WithGroup(name string) slog.Handler {
+ newHandlers := make([]slog.Handler, len(h.handlers))
+ for i, handler := range h.handlers {
+ newHandlers[i] = handler.WithGroup(name)
+ }
+ return &MultiHandler{handlers: newHandlers}
+}
+
+// SetupLogger 设置日志输出
+func SetupLogger(logLevel string, logFile string) error {
+ // 创建标准错误输出的处理器,使用自定义格式
+ stdHandler := NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{
+ Level: logLevelMap[logLevel],
+ })
+
+ // 如果没有指定日志文件,则仅使用标准错误处理器
+ if logFile == "" {
+ slog.SetDefault(slog.New(stdHandler))
+ return nil
+ }
+
+ // 确保日志文件所在目录存在
+ logDir := filepath.Dir(logFile)
+ if err := os.MkdirAll(logDir, 0755); err != nil {
+ return err
+ }
+
+ // 打开日志文件,如果不存在则创建,追加写入模式
+ file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
+ if err != nil {
+ return err
+ }
+
+ // 创建文件输出的处理器,使用自定义格式
+ fileHandler := NewCustomFormatHandler(file, &slog.HandlerOptions{
+ Level: logLevelMap[logLevel],
+ })
+
+ // 创建多输出处理器
+ multiHandler := &MultiHandler{
+ handlers: []slog.Handler{stdHandler, fileHandler},
+ }
+
+ // 设置全局日志处理器
+ slog.SetDefault(slog.New(multiHandler))
+ return nil
+}
+
+// InitDefaultLogger 初始化默认日志处理器
+func InitDefaultLogger(level slog.Level) {
+ slog.SetDefault(slog.New(NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{
+ Level: level,
+ })))
+}
diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go
index 87bb9f9..62f273b 100644
--- a/pkg/utils/utils_test.go
+++ b/pkg/utils/utils_test.go
@@ -167,4 +167,3 @@ func TestGetSessionName(t *testing.T) {
})
}
}
-
diff --git a/tools/main.go b/tools/main.go
index fd9c08c..6721a20 100644
--- a/tools/main.go
+++ b/tools/main.go
@@ -1,30 +1,30 @@
-package main
-
-import (
- "context"
- "os"
- "os/signal"
- "syscall"
-
- "github.com/ossrs/go-oryx-lib/logger"
- "github.com/ossrs/srs-bench/gb28181"
-)
-
-func main() {
- ctx := context.Background()
-
- var conf interface{}
- conf = gb28181.Parse(ctx)
-
- ctx, cancel := context.WithCancel(ctx)
- go func() {
- sigs := make(chan os.Signal, 1)
- signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT)
- for sig := range sigs {
- logger.Wf(ctx, "Quit for signal %v", sig)
- cancel()
- }
- }()
-
- gb28181.Run(ctx, conf)
-}
+package main
+
+import (
+ "context"
+ "os"
+ "os/signal"
+ "syscall"
+
+ "github.com/ossrs/go-oryx-lib/logger"
+ "github.com/ossrs/srs-bench/gb28181"
+)
+
+func main() {
+ ctx := context.Background()
+
+ var conf interface{}
+ conf = gb28181.Parse(ctx)
+
+ ctx, cancel := context.WithCancel(ctx)
+ go func() {
+ sigs := make(chan os.Signal, 1)
+ signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT)
+ for sig := range sigs {
+ logger.Wf(ctx, "Quit for signal %v", sig)
+ cancel()
+ }
+ }()
+
+ gb28181.Run(ctx, conf)
+}