From 156f07644d535493bf174a3514eace3cc52b231f Mon Sep 17 00:00:00 2001 From: "haibo.chen" <495810242@qq.com> Date: Wed, 15 Oct 2025 10:05:52 +0800 Subject: [PATCH] gofmt --- pkg/config/config_test.go | 3 +- pkg/db/media_server.go | 304 ++++++++++++------------- pkg/media/zlm.go | 118 +++++----- pkg/models/gb28181.go | 212 +++++++++--------- pkg/models/types.go | 160 +++++++------- pkg/models/types_test.go | 1 - pkg/service/auth.go | 184 ++++++++-------- pkg/service/auth_test.go | 1 - pkg/service/cascade.go | 34 +-- pkg/service/inbound.go | 342 ++++++++++++++--------------- pkg/service/ptz.go | 162 +++++++------- pkg/service/ptz_test.go | 3 +- pkg/service/stack/request.go | 136 ++++++------ pkg/service/stack/response.go | 82 +++---- pkg/service/uac.go | 244 ++++++++++----------- pkg/utils/logger.go | 402 +++++++++++++++++----------------- pkg/utils/utils_test.go | 1 - tools/main.go | 60 ++--- 18 files changed, 1222 insertions(+), 1227 deletions(-) diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 546a9a6..c2cbbdb 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -133,7 +133,7 @@ func TestLoadConfigInvalid(t *testing.T) { func TestGetLocalIP(t *testing.T) { ip, err := GetLocalIP() - + // 在某些环境下可能没有网络接口,所以允许返回错误 if err != nil { t.Logf("GetLocalIP returned error (may be expected in some environments): %v", err) @@ -152,4 +152,3 @@ func TestGetLocalIP(t *testing.T) { t.Logf("Local IP: %s", ip) } - diff --git a/pkg/db/media_server.go b/pkg/db/media_server.go index f2ca72e..5fa1c9d 100644 --- a/pkg/db/media_server.go +++ b/pkg/db/media_server.go @@ -1,152 +1,152 @@ -package db - -import ( - "database/sql" - "sync" - - "github.com/ossrs/srs-sip/pkg/models" - _ "modernc.org/sqlite" -) - -var ( - instance *MediaServerDB - once sync.Once -) - -type MediaServerDB struct { - models.MediaServerResponse - db *sql.DB -} - -// GetInstance 返回 MediaServerDB 的单例实例 -func GetInstance(dbPath string) (*MediaServerDB, error) { - var err error - once.Do(func() { - instance, err = NewMediaServerDB(dbPath) - }) - if err != nil { - return nil, err - } - return instance, nil -} - -func NewMediaServerDB(dbPath string) (*MediaServerDB, error) { - db, err := sql.Open("sqlite", dbPath) - if err != nil { - return nil, err - } - - // 创建媒体服务器表 - _, err = db.Exec(` - CREATE TABLE IF NOT EXISTS media_servers ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - type TEXT NOT NULL, - name TEXT NOT NULL, - ip TEXT NOT NULL, - port INTEGER NOT NULL, - username TEXT, - password TEXT, - secret TEXT, - is_default INTEGER NOT NULL DEFAULT 0, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - `) - if err != nil { - return nil, err - } - - return &MediaServerDB{db: db}, nil -} - -// GetMediaServerByNameAndIP 根据名称和IP查询媒体服务器 -func (m *MediaServerDB) GetMediaServerByNameAndIP(name, ip string) (*models.MediaServerResponse, error) { - var ms models.MediaServerResponse - err := m.db.QueryRow(` - SELECT id, name, type, ip, port, username, password, secret, is_default, created_at - FROM media_servers WHERE name = ? AND ip = ? - `, name, ip).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt) - if err != nil { - return nil, err - } - return &ms, nil -} - -func (m *MediaServerDB) AddMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error { - _, err := m.db.Exec(` - INSERT INTO media_servers (name, type, ip, port, username, password, secret, is_default) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - `, name, serverType, ip, port, username, password, secret, isDefault) - return err -} - -// AddOrUpdateMediaServer 添加或更新媒体服务器(如果已存在则更新) -func (m *MediaServerDB) AddOrUpdateMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error { - // 检查是否已存在 - existing, err := m.GetMediaServerByNameAndIP(name, ip) - if err == nil && existing != nil { - // 已存在,更新记录 - _, err = m.db.Exec(` - UPDATE media_servers - SET type = ?, port = ?, username = ?, password = ?, secret = ?, is_default = ? - WHERE name = ? AND ip = ? - `, serverType, port, username, password, secret, isDefault, name, ip) - return err - } - - // 不存在,插入新记录 - return m.AddMediaServer(name, serverType, ip, port, username, password, secret, isDefault) -} - -func (m *MediaServerDB) DeleteMediaServer(id int) error { - _, err := m.db.Exec("DELETE FROM media_servers WHERE id = ?", id) - return err -} - -func (m *MediaServerDB) GetMediaServer(id int) (*models.MediaServerResponse, error) { - var ms models.MediaServerResponse - err := m.db.QueryRow(` - SELECT id, name, type, ip, port, username, password, secret, is_default, created_at - FROM media_servers WHERE id = ? - `, id).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt) - if err != nil { - return nil, err - } - return &ms, nil -} - -func (m *MediaServerDB) ListMediaServers() ([]models.MediaServerResponse, error) { - rows, err := m.db.Query(` - SELECT id, name, type, ip, port, username, password, secret, is_default, created_at - FROM media_servers ORDER BY created_at DESC - `) - if err != nil { - return nil, err - } - defer rows.Close() - - var servers []models.MediaServerResponse - for rows.Next() { - var ms models.MediaServerResponse - err := rows.Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt) - if err != nil { - return nil, err - } - servers = append(servers, ms) - } - return servers, nil -} - -func (m *MediaServerDB) SetDefaultMediaServer(id int) error { - // 先将所有服务器设置为非默认 - if _, err := m.db.Exec("UPDATE media_servers SET is_default = 0"); err != nil { - return err - } - - // 将指定ID的服务器设置为默认 - _, err := m.db.Exec("UPDATE media_servers SET is_default = 1 WHERE id = ?", id) - return err -} - -func (m *MediaServerDB) Close() error { - return m.db.Close() -} +package db + +import ( + "database/sql" + "sync" + + "github.com/ossrs/srs-sip/pkg/models" + _ "modernc.org/sqlite" +) + +var ( + instance *MediaServerDB + once sync.Once +) + +type MediaServerDB struct { + models.MediaServerResponse + db *sql.DB +} + +// GetInstance 返回 MediaServerDB 的单例实例 +func GetInstance(dbPath string) (*MediaServerDB, error) { + var err error + once.Do(func() { + instance, err = NewMediaServerDB(dbPath) + }) + if err != nil { + return nil, err + } + return instance, nil +} + +func NewMediaServerDB(dbPath string) (*MediaServerDB, error) { + db, err := sql.Open("sqlite", dbPath) + if err != nil { + return nil, err + } + + // 创建媒体服务器表 + _, err = db.Exec(` + CREATE TABLE IF NOT EXISTS media_servers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + type TEXT NOT NULL, + name TEXT NOT NULL, + ip TEXT NOT NULL, + port INTEGER NOT NULL, + username TEXT, + password TEXT, + secret TEXT, + is_default INTEGER NOT NULL DEFAULT 0, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + `) + if err != nil { + return nil, err + } + + return &MediaServerDB{db: db}, nil +} + +// GetMediaServerByNameAndIP 根据名称和IP查询媒体服务器 +func (m *MediaServerDB) GetMediaServerByNameAndIP(name, ip string) (*models.MediaServerResponse, error) { + var ms models.MediaServerResponse + err := m.db.QueryRow(` + SELECT id, name, type, ip, port, username, password, secret, is_default, created_at + FROM media_servers WHERE name = ? AND ip = ? + `, name, ip).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt) + if err != nil { + return nil, err + } + return &ms, nil +} + +func (m *MediaServerDB) AddMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error { + _, err := m.db.Exec(` + INSERT INTO media_servers (name, type, ip, port, username, password, secret, is_default) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + `, name, serverType, ip, port, username, password, secret, isDefault) + return err +} + +// AddOrUpdateMediaServer 添加或更新媒体服务器(如果已存在则更新) +func (m *MediaServerDB) AddOrUpdateMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error { + // 检查是否已存在 + existing, err := m.GetMediaServerByNameAndIP(name, ip) + if err == nil && existing != nil { + // 已存在,更新记录 + _, err = m.db.Exec(` + UPDATE media_servers + SET type = ?, port = ?, username = ?, password = ?, secret = ?, is_default = ? + WHERE name = ? AND ip = ? + `, serverType, port, username, password, secret, isDefault, name, ip) + return err + } + + // 不存在,插入新记录 + return m.AddMediaServer(name, serverType, ip, port, username, password, secret, isDefault) +} + +func (m *MediaServerDB) DeleteMediaServer(id int) error { + _, err := m.db.Exec("DELETE FROM media_servers WHERE id = ?", id) + return err +} + +func (m *MediaServerDB) GetMediaServer(id int) (*models.MediaServerResponse, error) { + var ms models.MediaServerResponse + err := m.db.QueryRow(` + SELECT id, name, type, ip, port, username, password, secret, is_default, created_at + FROM media_servers WHERE id = ? + `, id).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt) + if err != nil { + return nil, err + } + return &ms, nil +} + +func (m *MediaServerDB) ListMediaServers() ([]models.MediaServerResponse, error) { + rows, err := m.db.Query(` + SELECT id, name, type, ip, port, username, password, secret, is_default, created_at + FROM media_servers ORDER BY created_at DESC + `) + if err != nil { + return nil, err + } + defer rows.Close() + + var servers []models.MediaServerResponse + for rows.Next() { + var ms models.MediaServerResponse + err := rows.Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt) + if err != nil { + return nil, err + } + servers = append(servers, ms) + } + return servers, nil +} + +func (m *MediaServerDB) SetDefaultMediaServer(id int) error { + // 先将所有服务器设置为非默认 + if _, err := m.db.Exec("UPDATE media_servers SET is_default = 0"); err != nil { + return err + } + + // 将指定ID的服务器设置为默认 + _, err := m.db.Exec("UPDATE media_servers SET is_default = 1 WHERE id = ?", id) + return err +} + +func (m *MediaServerDB) Close() error { + return m.db.Close() +} diff --git a/pkg/media/zlm.go b/pkg/media/zlm.go index ebc6619..a2abc98 100644 --- a/pkg/media/zlm.go +++ b/pkg/media/zlm.go @@ -1,59 +1,59 @@ -package media - -import ( - "context" - - "github.com/ossrs/go-oryx-lib/errors" -) - -type Zlm struct { - Ctx context.Context - Schema string // The schema of ZLM, eg: http - Addr string // The address of ZLM, eg: localhost:8085 - Secret string // The secret of ZLM, eg: ZLMediaKit_secret -} - -// /index/api/openRtpServer -// secret={{ZLMediaKit_secret}}&port=0&enable_tcp=1&stream_id=test2 -func (z *Zlm) Publish(id, ssrc string) (int, error) { - - res := struct { - Code int `json:"code"` - Port int `json:"port"` - }{} - - if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/openRtpServer?secret="+z.Secret+"&port=0&enable_tcp=1&stream_id="+id+"&ssrc="+ssrc, nil, &res); err != nil { - return 0, errors.Wrapf(err, "gb/v1/publish") - } - return res.Port, nil -} - -// /index/api/closeRtpServer -func (z *Zlm) Unpublish(id string) error { - res := struct { - Code int `json:"code"` - }{} - if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/closeRtpServer?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil { - return errors.Wrapf(err, "gb/v1/publish") - } - return nil -} - -// /index/api/getMediaList -func (z *Zlm) GetStreamStatus(id string) (bool, error) { - res := struct { - Code int `json:"code"` - }{} - if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/getMediaList?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil { - return false, errors.Wrapf(err, "gb/v1/publish") - } - return res.Code == 0, nil -} - -func (z *Zlm) GetAddr() string { - return z.Addr -} - -func (z *Zlm) GetWebRTCAddr(id string) string { - return "http://" + z.Addr + "/index/api/webrtc?app=rtp&stream=" + id + "&type=play" -} +package media + +import ( + "context" + + "github.com/ossrs/go-oryx-lib/errors" +) + +type Zlm struct { + Ctx context.Context + Schema string // The schema of ZLM, eg: http + Addr string // The address of ZLM, eg: localhost:8085 + Secret string // The secret of ZLM, eg: ZLMediaKit_secret +} + +// /index/api/openRtpServer +// secret={{ZLMediaKit_secret}}&port=0&enable_tcp=1&stream_id=test2 +func (z *Zlm) Publish(id, ssrc string) (int, error) { + + res := struct { + Code int `json:"code"` + Port int `json:"port"` + }{} + + if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/openRtpServer?secret="+z.Secret+"&port=0&enable_tcp=1&stream_id="+id+"&ssrc="+ssrc, nil, &res); err != nil { + return 0, errors.Wrapf(err, "gb/v1/publish") + } + return res.Port, nil +} + +// /index/api/closeRtpServer +func (z *Zlm) Unpublish(id string) error { + res := struct { + Code int `json:"code"` + }{} + if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/closeRtpServer?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil { + return errors.Wrapf(err, "gb/v1/publish") + } + return nil +} + +// /index/api/getMediaList +func (z *Zlm) GetStreamStatus(id string) (bool, error) { + res := struct { + Code int `json:"code"` + }{} + if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/getMediaList?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil { + return false, errors.Wrapf(err, "gb/v1/publish") + } + return res.Code == 0, nil +} + +func (z *Zlm) GetAddr() string { + return z.Addr +} + +func (z *Zlm) GetWebRTCAddr(id string) string { + return "http://" + z.Addr + "/index/api/webrtc?app=rtp&stream=" + id + "&type=play" +} diff --git a/pkg/models/gb28181.go b/pkg/models/gb28181.go index e7c52e4..b8c79be 100644 --- a/pkg/models/gb28181.go +++ b/pkg/models/gb28181.go @@ -1,106 +1,106 @@ -package models - -import "encoding/xml" - -type Record struct { - DeviceID string `xml:"DeviceID" json:"device_id"` - Name string `xml:"Name" json:"name"` - FilePath string `xml:"FilePath" json:"file_path"` - Address string `xml:"Address" json:"address"` - StartTime string `xml:"StartTime" json:"start_time"` - EndTime string `xml:"EndTime" json:"end_time"` - Secrecy int `xml:"Secrecy" json:"secrecy"` - Type string `xml:"Type" json:"type"` -} - -// Example XML structure for channel info: -// -// -// 34020000001320000002 -// 209 -// UNIVIEW -// HIC6622-IR@X33-VF -// IPC-B2202.7.11.230222 -// CivilCode -//
Address
-// 1 -// 75015310072008100002 -// 0 -// 1 -// 0 -// ON -// 0.0000000 -// 0.0000000 -// -// 1 -// 6/4/2 -// 0 -// -//
- -type ChannelInfo struct { - DeviceID string `json:"device_id"` - ParentID string `json:"parent_id"` - Name string `json:"name"` - Manufacturer string `json:"manufacturer"` - Model string `json:"model"` - Owner string `json:"owner"` - CivilCode string `json:"civil_code"` - Address string `json:"address"` - Port int `json:"port"` - Parental int `json:"parental"` - SafetyWay int `json:"safety_way"` - RegisterWay int `json:"register_way"` - Secrecy int `json:"secrecy"` - IPAddress string `json:"ip_address"` - Status ChannelStatus `json:"status"` - Longitude float64 `json:"longitude"` - Latitude float64 `json:"latitude"` - Info struct { - PTZType int `json:"ptz_type"` - Resolution string `json:"resolution"` - DownloadSpeed string `json:"download_speed"` // Speed levels: 1/2/4/8 - } `json:"info"` - - // Custom fields - Ssrc string `json:"ssrc"` -} - -type ChannelStatus string - -// BasicParam -// -// -// -// -// -// -// -// -// -// -// -// -// -// -type BasicParam struct { - Name string `xml:"Name"` - Expiration int `xml:"Expiration"` - HeartBeatInterval int `xml:"HeartBeatInterval"` - HeartBeatCount int `xml:"HeartBeatCount"` -} - -type XmlMessageInfo struct { - XMLName xml.Name - CmdType string - SN int - DeviceID string - DeviceName string - Manufacturer string - Model string - Channel string - DeviceList []ChannelInfo `xml:"DeviceList>Item"` - RecordList []*Record `xml:"RecordList>Item"` - BasicParam BasicParam `xml:"BasicParam"` - SumNum int -} +package models + +import "encoding/xml" + +type Record struct { + DeviceID string `xml:"DeviceID" json:"device_id"` + Name string `xml:"Name" json:"name"` + FilePath string `xml:"FilePath" json:"file_path"` + Address string `xml:"Address" json:"address"` + StartTime string `xml:"StartTime" json:"start_time"` + EndTime string `xml:"EndTime" json:"end_time"` + Secrecy int `xml:"Secrecy" json:"secrecy"` + Type string `xml:"Type" json:"type"` +} + +// Example XML structure for channel info: +// +// +// 34020000001320000002 +// 209 +// UNIVIEW +// HIC6622-IR@X33-VF +// IPC-B2202.7.11.230222 +// CivilCode +//
Address
+// 1 +// 75015310072008100002 +// 0 +// 1 +// 0 +// ON +// 0.0000000 +// 0.0000000 +// +// 1 +// 6/4/2 +// 0 +// +//
+ +type ChannelInfo struct { + DeviceID string `json:"device_id"` + ParentID string `json:"parent_id"` + Name string `json:"name"` + Manufacturer string `json:"manufacturer"` + Model string `json:"model"` + Owner string `json:"owner"` + CivilCode string `json:"civil_code"` + Address string `json:"address"` + Port int `json:"port"` + Parental int `json:"parental"` + SafetyWay int `json:"safety_way"` + RegisterWay int `json:"register_way"` + Secrecy int `json:"secrecy"` + IPAddress string `json:"ip_address"` + Status ChannelStatus `json:"status"` + Longitude float64 `json:"longitude"` + Latitude float64 `json:"latitude"` + Info struct { + PTZType int `json:"ptz_type"` + Resolution string `json:"resolution"` + DownloadSpeed string `json:"download_speed"` // Speed levels: 1/2/4/8 + } `json:"info"` + + // Custom fields + Ssrc string `json:"ssrc"` +} + +type ChannelStatus string + +// BasicParam +// +// +// +// +// +// +// +// +// +// +// +// +// +// +type BasicParam struct { + Name string `xml:"Name"` + Expiration int `xml:"Expiration"` + HeartBeatInterval int `xml:"HeartBeatInterval"` + HeartBeatCount int `xml:"HeartBeatCount"` +} + +type XmlMessageInfo struct { + XMLName xml.Name + CmdType string + SN int + DeviceID string + DeviceName string + Manufacturer string + Model string + Channel string + DeviceList []ChannelInfo `xml:"DeviceList>Item"` + RecordList []*Record `xml:"RecordList>Item"` + BasicParam BasicParam `xml:"BasicParam"` + SumNum int +} diff --git a/pkg/models/types.go b/pkg/models/types.go index 6ab8a60..ba7e821 100644 --- a/pkg/models/types.go +++ b/pkg/models/types.go @@ -1,80 +1,80 @@ -package models - -type BaseRequest struct { - DeviceID string `json:"device_id"` - ChannelID string `json:"channel_id"` -} - -type InviteRequest struct { - BaseRequest - MediaServerId int `json:"media_server_id"` - PlayType int `json:"play_type"` // 0: live, 1: playback, 2: download - SubStream int `json:"sub_stream"` - StartTime int64 `json:"start_time"` - EndTime int64 `json:"end_time"` -} - -type InviteResponse struct { - ChannelID string `json:"channel_id"` - URL string `json:"url"` -} - -type SessionRequest struct { - BaseRequest - URL string `json:"url"` -} - -type ByeRequest struct { - SessionRequest -} - -type PauseRequest struct { - SessionRequest -} - -type ResumeRequest struct { - SessionRequest -} - -type SpeedRequest struct { - SessionRequest - Speed float32 `json:"speed"` -} - -type PTZControlRequest struct { - BaseRequest - PTZ string `json:"ptz"` - Speed string `json:"speed"` -} - -type QueryRecordRequest struct { - BaseRequest - StartTime int64 `json:"start_time"` - EndTime int64 `json:"end_time"` -} - -type MediaServer struct { - Name string `json:"name"` - Type string `json:"type"` - IP string `json:"ip"` - Port int `json:"port"` - Username string `json:"username"` - Password string `json:"password"` - Secret string `json:"secret"` - IsDefault int `json:"is_default"` -} - -type MediaServerRequest struct { - MediaServer -} - -type MediaServerResponse struct { - MediaServer - ID int `json:"id"` - CreatedAt string `json:"created_at"` -} - -type CommonResponse struct { - Code int `json:"code"` - Data interface{} `json:"data"` -} +package models + +type BaseRequest struct { + DeviceID string `json:"device_id"` + ChannelID string `json:"channel_id"` +} + +type InviteRequest struct { + BaseRequest + MediaServerId int `json:"media_server_id"` + PlayType int `json:"play_type"` // 0: live, 1: playback, 2: download + SubStream int `json:"sub_stream"` + StartTime int64 `json:"start_time"` + EndTime int64 `json:"end_time"` +} + +type InviteResponse struct { + ChannelID string `json:"channel_id"` + URL string `json:"url"` +} + +type SessionRequest struct { + BaseRequest + URL string `json:"url"` +} + +type ByeRequest struct { + SessionRequest +} + +type PauseRequest struct { + SessionRequest +} + +type ResumeRequest struct { + SessionRequest +} + +type SpeedRequest struct { + SessionRequest + Speed float32 `json:"speed"` +} + +type PTZControlRequest struct { + BaseRequest + PTZ string `json:"ptz"` + Speed string `json:"speed"` +} + +type QueryRecordRequest struct { + BaseRequest + StartTime int64 `json:"start_time"` + EndTime int64 `json:"end_time"` +} + +type MediaServer struct { + Name string `json:"name"` + Type string `json:"type"` + IP string `json:"ip"` + Port int `json:"port"` + Username string `json:"username"` + Password string `json:"password"` + Secret string `json:"secret"` + IsDefault int `json:"is_default"` +} + +type MediaServerRequest struct { + MediaServer +} + +type MediaServerResponse struct { + MediaServer + ID int `json:"id"` + CreatedAt string `json:"created_at"` +} + +type CommonResponse struct { + Code int `json:"code"` + Data interface{} `json:"data"` +} diff --git a/pkg/models/types_test.go b/pkg/models/types_test.go index 7d5f7e5..bd873bb 100644 --- a/pkg/models/types_test.go +++ b/pkg/models/types_test.go @@ -335,4 +335,3 @@ func TestCommonResponseWithDifferentDataTypes(t *testing.T) { }) } } - diff --git a/pkg/service/auth.go b/pkg/service/auth.go index 8f1538b..06fc879 100644 --- a/pkg/service/auth.go +++ b/pkg/service/auth.go @@ -1,92 +1,92 @@ -package service - -import ( - "crypto/md5" - "crypto/rand" - "encoding/hex" - "fmt" - "strings" -) - -// AuthInfo 存储解析后的认证信息 -type AuthInfo struct { - Username string - Realm string - Nonce string - URI string - Response string - Algorithm string - Method string -} - -// GenerateNonce 生成随机 nonce 字符串 -func GenerateNonce() string { - b := make([]byte, 16) - rand.Read(b) - return fmt.Sprintf("%x", b) -} - -// ParseAuthorization 解析 SIP Authorization 头 -// Authorization: Digest username="34020000001320000001",realm="3402000000", -// nonce="44010b73623249f6916a6acf7c316b8e",uri="sip:34020000002000000001@3402000000", -// response="e4ca3fdc5869fa1c544ea7af60014444",algorithm=MD5 -func ParseAuthorization(auth string) *AuthInfo { - auth = strings.TrimPrefix(auth, "Digest ") - parts := strings.Split(auth, ",") - result := &AuthInfo{} - - for _, part := range parts { - part = strings.TrimSpace(part) - if !strings.Contains(part, "=") { - continue - } - - kv := strings.SplitN(part, "=", 2) - key := strings.TrimSpace(kv[0]) - value := strings.Trim(strings.TrimSpace(kv[1]), "\"") - - switch key { - case "username": - result.Username = value - case "realm": - result.Realm = value - case "nonce": - result.Nonce = value - case "uri": - result.URI = value - case "response": - result.Response = value - case "algorithm": - result.Algorithm = value - } - } - - return result -} - -// ValidateAuth 验证 SIP 认证信息 -func ValidateAuth(authInfo *AuthInfo, password string) bool { - if authInfo == nil { - return false - } - - // 默认方法为 REGISTER - method := "REGISTER" - if authInfo.Method != "" { - method = authInfo.Method - } - - // 计算 MD5 哈希 - ha1 := md5Hex(authInfo.Username + ":" + authInfo.Realm + ":" + password) - ha2 := md5Hex(method + ":" + authInfo.URI) - correctResponse := md5Hex(ha1 + ":" + authInfo.Nonce + ":" + ha2) - - return authInfo.Response == correctResponse -} - -// md5Hex 计算字符串的 MD5 哈希值并返回十六进制字符串 -func md5Hex(s string) string { - hash := md5.New() - hash.Write([]byte(s)) - return hex.EncodeToString(hash.Sum(nil)) -} \ No newline at end of file +package service + +import ( + "crypto/md5" + "crypto/rand" + "encoding/hex" + "fmt" + "strings" +) + +// AuthInfo 存储解析后的认证信息 +type AuthInfo struct { + Username string + Realm string + Nonce string + URI string + Response string + Algorithm string + Method string +} + +// GenerateNonce 生成随机 nonce 字符串 +func GenerateNonce() string { + b := make([]byte, 16) + rand.Read(b) + return fmt.Sprintf("%x", b) +} + +// ParseAuthorization 解析 SIP Authorization 头 +// Authorization: Digest username="34020000001320000001",realm="3402000000", +// nonce="44010b73623249f6916a6acf7c316b8e",uri="sip:34020000002000000001@3402000000", +// response="e4ca3fdc5869fa1c544ea7af60014444",algorithm=MD5 +func ParseAuthorization(auth string) *AuthInfo { + auth = strings.TrimPrefix(auth, "Digest ") + parts := strings.Split(auth, ",") + result := &AuthInfo{} + + for _, part := range parts { + part = strings.TrimSpace(part) + if !strings.Contains(part, "=") { + continue + } + + kv := strings.SplitN(part, "=", 2) + key := strings.TrimSpace(kv[0]) + value := strings.Trim(strings.TrimSpace(kv[1]), "\"") + + switch key { + case "username": + result.Username = value + case "realm": + result.Realm = value + case "nonce": + result.Nonce = value + case "uri": + result.URI = value + case "response": + result.Response = value + case "algorithm": + result.Algorithm = value + } + } + + return result +} + +// ValidateAuth 验证 SIP 认证信息 +func ValidateAuth(authInfo *AuthInfo, password string) bool { + if authInfo == nil { + return false + } + + // 默认方法为 REGISTER + method := "REGISTER" + if authInfo.Method != "" { + method = authInfo.Method + } + + // 计算 MD5 哈希 + ha1 := md5Hex(authInfo.Username + ":" + authInfo.Realm + ":" + password) + ha2 := md5Hex(method + ":" + authInfo.URI) + correctResponse := md5Hex(ha1 + ":" + authInfo.Nonce + ":" + ha2) + + return authInfo.Response == correctResponse +} + +// md5Hex 计算字符串的 MD5 哈希值并返回十六进制字符串 +func md5Hex(s string) string { + hash := md5.New() + hash.Write([]byte(s)) + return hex.EncodeToString(hash.Sum(nil)) +} diff --git a/pkg/service/auth_test.go b/pkg/service/auth_test.go index f0f0835..6a4f3b6 100644 --- a/pkg/service/auth_test.go +++ b/pkg/service/auth_test.go @@ -343,4 +343,3 @@ func TestParseAuthorizationQuotedValues(t *testing.T) { t.Logf("Realm value: '%s'", result.Realm) } } - diff --git a/pkg/service/cascade.go b/pkg/service/cascade.go index e7360c6..d214508 100644 --- a/pkg/service/cascade.go +++ b/pkg/service/cascade.go @@ -1,17 +1,17 @@ -package service - -import ( - "context" - - "github.com/emiago/sipgo" - "github.com/ossrs/srs-sip/pkg/config" -) - -type Cascade struct { - ua *sipgo.UserAgent - sipCli *sipgo.Client - sipSvr *sipgo.Server - - ctx context.Context - conf *config.MainConfig -} +package service + +import ( + "context" + + "github.com/emiago/sipgo" + "github.com/ossrs/srs-sip/pkg/config" +) + +type Cascade struct { + ua *sipgo.UserAgent + sipCli *sipgo.Client + sipSvr *sipgo.Server + + ctx context.Context + conf *config.MainConfig +} diff --git a/pkg/service/inbound.go b/pkg/service/inbound.go index 3114a65..7737ac5 100644 --- a/pkg/service/inbound.go +++ b/pkg/service/inbound.go @@ -1,171 +1,171 @@ -package service - -import ( - "bytes" - "encoding/xml" - "fmt" - "log/slog" - "net" - "net/http" - "strconv" - - "github.com/emiago/sipgo/sip" - "github.com/ossrs/srs-sip/pkg/models" - "github.com/ossrs/srs-sip/pkg/service/stack" - "golang.org/x/net/html/charset" -) - -const GB28181_ID_LENGTH = 20 - -func (s *UAS) isSameIP(addr1, addr2 string) bool { - ip1, _, err1 := net.SplitHostPort(addr1) - ip2, _, err2 := net.SplitHostPort(addr2) - - // 如果解析出错,回退到完整字符串比较 - if err1 != nil || err2 != nil { - return addr1 == addr2 - } - - return ip1 == ip2 -} - -func (s *UAS) onRegister(req *sip.Request, tx sip.ServerTransaction) { - id := req.From().Address.User - if len(id) != GB28181_ID_LENGTH { - slog.Error("invalid device ID") - return - } - - slog.Debug(fmt.Sprintf("Received REGISTER %s", req.String())) - - if s.conf.GB28181.Auth.Enable { - // Check if Authorization header exists - authHeader := req.GetHeaders("Authorization") - - // If no Authorization header, send 401 response to request authentication - if len(authHeader) == 0 { - nonce := GenerateNonce() - resp := stack.NewUnauthorizedResponse(req, http.StatusUnauthorized, "Unauthorized", nonce, s.conf.GB28181.Realm) - _ = tx.Respond(resp) - return - } - - // Validate Authorization - authInfo := ParseAuthorization(authHeader[0].Value()) - if !ValidateAuth(authInfo, s.conf.GB28181.Auth.Password) { - slog.Error("auth failed", "device_id", id, "source", req.Source()) - s.respondRegister(req, http.StatusForbidden, "Auth Failed", tx) - return - } - } - - isUnregister := false - if exps := req.GetHeaders("Expires"); len(exps) > 0 { - exp := exps[0] - expSec, err := strconv.ParseInt(exp.Value(), 10, 32) - if err != nil { - slog.Error("parse expires header error", "error", err.Error()) - return - } - if expSec == 0 { - isUnregister = true - } - } else { - slog.Error("empty expires header") - return - } - - if isUnregister { - DM.RemoveDevice(id) - slog.Warn("Device unregistered", "device_id", id) - return - } else { - if d, ok := DM.GetDevice(id); !ok { - DM.AddDevice(id, &DeviceInfo{ - DeviceID: id, - SourceAddr: req.Source(), - NetworkType: req.Transport(), - }) - s.respondRegister(req, http.StatusOK, "OK", tx) - slog.Info(fmt.Sprintf("Register success %s %s", id, req.Source())) - - go s.ConfigDownload(id) - go s.Catalog(id) - } else { - if d.SourceAddr != "" && !s.isSameIP(d.SourceAddr, req.Source()) { - slog.Error("Device already registered", "device_id", id, "old_source", d.SourceAddr, "new_source", req.Source()) - // TODO: 如果ID重复,应采用虚拟ID - s.respondRegister(req, http.StatusBadRequest, "Conflict Device ID", tx) - } else { - d.SourceAddr = req.Source() - d.NetworkType = req.Transport() - DM.UpdateDevice(id, d) - s.respondRegister(req, http.StatusOK, "OK", tx) - - slog.Info(fmt.Sprintf("Re-register success %s %s", id, req.Source())) - } - } - } -} - -func (s *UAS) respondRegister(req *sip.Request, code sip.StatusCode, reason string, tx sip.ServerTransaction) { - res := stack.NewRegisterResponse(req, code, reason) - _ = tx.Respond(res) - -} - -func (s *UAS) onMessage(req *sip.Request, tx sip.ServerTransaction) { - id := req.From().Address.User - if len(id) != 20 { - slog.Error("invalid device ID", "request", req.String()) - } - - slog.Debug(fmt.Sprintf("Received MESSAGE %s", req.String())) - - temp := &models.XmlMessageInfo{} - decoder := xml.NewDecoder(bytes.NewReader([]byte(req.Body()))) - decoder.CharsetReader = charset.NewReaderLabel - if err := decoder.Decode(temp); err != nil { - slog.Error("decode message error", "error", err.Error(), "message", req.Body()) - } - - slog.Info(fmt.Sprintf("Received MESSAGE %s %s %s", temp.CmdType, temp.DeviceID, req.Source())) - - var body string - switch temp.CmdType { - case "Keepalive": - if d, ok := DM.GetDevice(temp.DeviceID); ok && d.Online { - // 更新设备心跳时间 - DM.UpdateDeviceHeartbeat(temp.DeviceID) - } else { - tx.Respond(sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil)) - return - } - case "SensorCatalog": // 兼容宇视,非国标 - case "Catalog": - DM.UpdateChannels(temp.DeviceID, temp.DeviceList...) - //go s.AutoInvite(temp.DeviceID, temp.DeviceList...) - case "ConfigDownload": - DM.UpdateDeviceConfig(temp.DeviceID, &temp.BasicParam) - case "Alarm": - slog.Info("Alarm") - case "RecordInfo": - // 从 recordQueryResults 中获取对应通道的结果通道 - if ch, ok := s.recordQueryResults.Load(temp.DeviceID); ok { - // 发送查询结果 - resultChan := ch.(chan *models.XmlMessageInfo) - resultChan <- temp - } - default: - slog.Warn("Not supported CmdType", "cmd_type", temp.CmdType) - response := sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil) - tx.Respond(response) - return - } - tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", []byte(body))) -} - -func (s *UAS) onNotify(req *sip.Request, tx sip.ServerTransaction) { - slog.Debug(fmt.Sprintf("Received NOTIFY %s", req.String())) - tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", nil)) -} +package service + +import ( + "bytes" + "encoding/xml" + "fmt" + "log/slog" + "net" + "net/http" + "strconv" + + "github.com/emiago/sipgo/sip" + "github.com/ossrs/srs-sip/pkg/models" + "github.com/ossrs/srs-sip/pkg/service/stack" + "golang.org/x/net/html/charset" +) + +const GB28181_ID_LENGTH = 20 + +func (s *UAS) isSameIP(addr1, addr2 string) bool { + ip1, _, err1 := net.SplitHostPort(addr1) + ip2, _, err2 := net.SplitHostPort(addr2) + + // 如果解析出错,回退到完整字符串比较 + if err1 != nil || err2 != nil { + return addr1 == addr2 + } + + return ip1 == ip2 +} + +func (s *UAS) onRegister(req *sip.Request, tx sip.ServerTransaction) { + id := req.From().Address.User + if len(id) != GB28181_ID_LENGTH { + slog.Error("invalid device ID") + return + } + + slog.Debug(fmt.Sprintf("Received REGISTER %s", req.String())) + + if s.conf.GB28181.Auth.Enable { + // Check if Authorization header exists + authHeader := req.GetHeaders("Authorization") + + // If no Authorization header, send 401 response to request authentication + if len(authHeader) == 0 { + nonce := GenerateNonce() + resp := stack.NewUnauthorizedResponse(req, http.StatusUnauthorized, "Unauthorized", nonce, s.conf.GB28181.Realm) + _ = tx.Respond(resp) + return + } + + // Validate Authorization + authInfo := ParseAuthorization(authHeader[0].Value()) + if !ValidateAuth(authInfo, s.conf.GB28181.Auth.Password) { + slog.Error("auth failed", "device_id", id, "source", req.Source()) + s.respondRegister(req, http.StatusForbidden, "Auth Failed", tx) + return + } + } + + isUnregister := false + if exps := req.GetHeaders("Expires"); len(exps) > 0 { + exp := exps[0] + expSec, err := strconv.ParseInt(exp.Value(), 10, 32) + if err != nil { + slog.Error("parse expires header error", "error", err.Error()) + return + } + if expSec == 0 { + isUnregister = true + } + } else { + slog.Error("empty expires header") + return + } + + if isUnregister { + DM.RemoveDevice(id) + slog.Warn("Device unregistered", "device_id", id) + return + } else { + if d, ok := DM.GetDevice(id); !ok { + DM.AddDevice(id, &DeviceInfo{ + DeviceID: id, + SourceAddr: req.Source(), + NetworkType: req.Transport(), + }) + s.respondRegister(req, http.StatusOK, "OK", tx) + slog.Info(fmt.Sprintf("Register success %s %s", id, req.Source())) + + go s.ConfigDownload(id) + go s.Catalog(id) + } else { + if d.SourceAddr != "" && !s.isSameIP(d.SourceAddr, req.Source()) { + slog.Error("Device already registered", "device_id", id, "old_source", d.SourceAddr, "new_source", req.Source()) + // TODO: 如果ID重复,应采用虚拟ID + s.respondRegister(req, http.StatusBadRequest, "Conflict Device ID", tx) + } else { + d.SourceAddr = req.Source() + d.NetworkType = req.Transport() + DM.UpdateDevice(id, d) + s.respondRegister(req, http.StatusOK, "OK", tx) + + slog.Info(fmt.Sprintf("Re-register success %s %s", id, req.Source())) + } + } + } +} + +func (s *UAS) respondRegister(req *sip.Request, code sip.StatusCode, reason string, tx sip.ServerTransaction) { + res := stack.NewRegisterResponse(req, code, reason) + _ = tx.Respond(res) + +} + +func (s *UAS) onMessage(req *sip.Request, tx sip.ServerTransaction) { + id := req.From().Address.User + if len(id) != 20 { + slog.Error("invalid device ID", "request", req.String()) + } + + slog.Debug(fmt.Sprintf("Received MESSAGE %s", req.String())) + + temp := &models.XmlMessageInfo{} + decoder := xml.NewDecoder(bytes.NewReader([]byte(req.Body()))) + decoder.CharsetReader = charset.NewReaderLabel + if err := decoder.Decode(temp); err != nil { + slog.Error("decode message error", "error", err.Error(), "message", req.Body()) + } + + slog.Info(fmt.Sprintf("Received MESSAGE %s %s %s", temp.CmdType, temp.DeviceID, req.Source())) + + var body string + switch temp.CmdType { + case "Keepalive": + if d, ok := DM.GetDevice(temp.DeviceID); ok && d.Online { + // 更新设备心跳时间 + DM.UpdateDeviceHeartbeat(temp.DeviceID) + } else { + tx.Respond(sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil)) + return + } + case "SensorCatalog": // 兼容宇视,非国标 + case "Catalog": + DM.UpdateChannels(temp.DeviceID, temp.DeviceList...) + //go s.AutoInvite(temp.DeviceID, temp.DeviceList...) + case "ConfigDownload": + DM.UpdateDeviceConfig(temp.DeviceID, &temp.BasicParam) + case "Alarm": + slog.Info("Alarm") + case "RecordInfo": + // 从 recordQueryResults 中获取对应通道的结果通道 + if ch, ok := s.recordQueryResults.Load(temp.DeviceID); ok { + // 发送查询结果 + resultChan := ch.(chan *models.XmlMessageInfo) + resultChan <- temp + } + default: + slog.Warn("Not supported CmdType", "cmd_type", temp.CmdType) + response := sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil) + tx.Respond(response) + return + } + tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", []byte(body))) +} + +func (s *UAS) onNotify(req *sip.Request, tx sip.ServerTransaction) { + slog.Debug(fmt.Sprintf("Received NOTIFY %s", req.String())) + tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", nil)) +} diff --git a/pkg/service/ptz.go b/pkg/service/ptz.go index e561643..d62a4fa 100644 --- a/pkg/service/ptz.go +++ b/pkg/service/ptz.go @@ -1,81 +1,81 @@ -package service - -import "fmt" - -var ( - ptzCmdMap = map[string]uint8{ - "stop": 0, - "right": 1, - "left": 2, - "down": 4, - "downright": 5, - "downleft": 6, - "up": 8, - "upright": 9, - "upleft": 10, - "zoomin": 16, - "zoomout": 32, - } - - ptzSpeedMap = map[string]uint8{ - "1": 25, - "2": 50, - "3": 75, - "4": 100, - "5": 125, - "6": 150, - "7": 175, - "8": 200, - "9": 225, - "10": 255, - } - - defaultSpeed uint8 = 125 -) - -func getPTZSpeed(speed string) uint8 { - if v, ok := ptzSpeedMap[speed]; ok { - return v - } - return defaultSpeed -} - -func toPTZCmd(cmdName, speed string) (string, error) { - cmdCode, ok := ptzCmdMap[cmdName] - if !ok { - return "", fmt.Errorf("invalid ptz command: %q", cmdName) - } - - speedValue := getPTZSpeed(speed) - - var horizontalSpeed, verticalSpeed, zSpeed uint8 - - switch cmdName { - case "left", "right": - horizontalSpeed = speedValue - verticalSpeed = 0 - case "up", "down": - verticalSpeed = speedValue - horizontalSpeed = 0 - case "upleft", "upright", "downleft", "downright": - verticalSpeed = speedValue - horizontalSpeed = speedValue - case "zoomin", "zoomout": - zSpeed = speedValue << 4 // zoom速度在高4位 - default: - horizontalSpeed = 0 - verticalSpeed = 0 - zSpeed = 0 - } - - sum := uint16(0xA5) + uint16(0x0F) + uint16(0x01) + uint16(cmdCode) + uint16(horizontalSpeed) + uint16(verticalSpeed) + uint16(zSpeed) - checksum := uint8(sum % 256) - - return fmt.Sprintf("A50F01%02X%02X%02X%02X%02X", - cmdCode, - horizontalSpeed, - verticalSpeed, - zSpeed, - checksum, - ), nil -} +package service + +import "fmt" + +var ( + ptzCmdMap = map[string]uint8{ + "stop": 0, + "right": 1, + "left": 2, + "down": 4, + "downright": 5, + "downleft": 6, + "up": 8, + "upright": 9, + "upleft": 10, + "zoomin": 16, + "zoomout": 32, + } + + ptzSpeedMap = map[string]uint8{ + "1": 25, + "2": 50, + "3": 75, + "4": 100, + "5": 125, + "6": 150, + "7": 175, + "8": 200, + "9": 225, + "10": 255, + } + + defaultSpeed uint8 = 125 +) + +func getPTZSpeed(speed string) uint8 { + if v, ok := ptzSpeedMap[speed]; ok { + return v + } + return defaultSpeed +} + +func toPTZCmd(cmdName, speed string) (string, error) { + cmdCode, ok := ptzCmdMap[cmdName] + if !ok { + return "", fmt.Errorf("invalid ptz command: %q", cmdName) + } + + speedValue := getPTZSpeed(speed) + + var horizontalSpeed, verticalSpeed, zSpeed uint8 + + switch cmdName { + case "left", "right": + horizontalSpeed = speedValue + verticalSpeed = 0 + case "up", "down": + verticalSpeed = speedValue + horizontalSpeed = 0 + case "upleft", "upright", "downleft", "downright": + verticalSpeed = speedValue + horizontalSpeed = speedValue + case "zoomin", "zoomout": + zSpeed = speedValue << 4 // zoom速度在高4位 + default: + horizontalSpeed = 0 + verticalSpeed = 0 + zSpeed = 0 + } + + sum := uint16(0xA5) + uint16(0x0F) + uint16(0x01) + uint16(cmdCode) + uint16(horizontalSpeed) + uint16(verticalSpeed) + uint16(zSpeed) + checksum := uint8(sum % 256) + + return fmt.Sprintf("A50F01%02X%02X%02X%02X%02X", + cmdCode, + horizontalSpeed, + verticalSpeed, + zSpeed, + checksum, + ), nil +} diff --git a/pkg/service/ptz_test.go b/pkg/service/ptz_test.go index 14e5e88..a715a0f 100644 --- a/pkg/service/ptz_test.go +++ b/pkg/service/ptz_test.go @@ -141,7 +141,7 @@ func TestToPTZCmdSpecificCases(t *testing.T) { func TestToPTZCmdWithDifferentSpeeds(t *testing.T) { speeds := []string{"1", "5", "10"} - + for _, speed := range speeds { t.Run("Right with speed "+speed, func(t *testing.T) { result, err := toPTZCmd("right", speed) @@ -196,4 +196,3 @@ func TestPTZSpeedMap(t *testing.T) { }) } } - diff --git a/pkg/service/stack/request.go b/pkg/service/stack/request.go index ebb7e54..c9567ae 100644 --- a/pkg/service/stack/request.go +++ b/pkg/service/stack/request.go @@ -1,68 +1,68 @@ -package stack - -import ( - "github.com/emiago/sipgo/sip" - "github.com/ossrs/go-oryx-lib/errors" -) - -type OutboundConfig struct { - Transport string - Via string - From string - To string -} - -func NewRequest(method sip.RequestMethod, body []byte, conf OutboundConfig) (*sip.Request, error) { - if len(conf.From) != 20 || len(conf.To) != 20 { - return nil, errors.Errorf("From or To length is not 20") - } - - dest := conf.Via - to := sip.Uri{User: conf.To, Host: conf.To[:10]} - from := &sip.Uri{User: conf.From, Host: conf.From[:10]} - - fromHeader := &sip.FromHeader{Address: *from, Params: sip.NewParams()} - fromHeader.Params.Add("tag", sip.GenerateTagN(16)) - - req := sip.NewRequest(method, to) - req.AppendHeader(fromHeader) - req.AppendHeader(&sip.ToHeader{Address: to}) - req.AppendHeader(&sip.ContactHeader{Address: *from}) - req.AppendHeader(sip.NewHeader("Max-Forwards", "70")) - req.SetBody(body) - req.SetDestination(dest) - req.SetTransport(conf.Transport) - - return req, nil -} - -func NewRegisterRequest(conf OutboundConfig) (*sip.Request, error) { - req, err := NewRequest(sip.REGISTER, nil, conf) - if err != nil { - return nil, err - } - req.AppendHeader(sip.NewHeader("Expires", "3600")) - - return req, nil -} - -func NewInviteRequest(body []byte, subject string, conf OutboundConfig) (*sip.Request, error) { - req, err := NewRequest(sip.INVITE, body, conf) - if err != nil { - return nil, err - } - req.AppendHeader(sip.NewHeader("Content-Type", "application/sdp")) - req.AppendHeader(sip.NewHeader("Subject", subject)) - - return req, nil -} - -func NewMessageRequest(body []byte, conf OutboundConfig) (*sip.Request, error) { - req, err := NewRequest(sip.MESSAGE, body, conf) - if err != nil { - return nil, err - } - req.AppendHeader(sip.NewHeader("Content-Type", "Application/MANSCDP+xml")) - - return req, nil -} +package stack + +import ( + "github.com/emiago/sipgo/sip" + "github.com/ossrs/go-oryx-lib/errors" +) + +type OutboundConfig struct { + Transport string + Via string + From string + To string +} + +func NewRequest(method sip.RequestMethod, body []byte, conf OutboundConfig) (*sip.Request, error) { + if len(conf.From) != 20 || len(conf.To) != 20 { + return nil, errors.Errorf("From or To length is not 20") + } + + dest := conf.Via + to := sip.Uri{User: conf.To, Host: conf.To[:10]} + from := &sip.Uri{User: conf.From, Host: conf.From[:10]} + + fromHeader := &sip.FromHeader{Address: *from, Params: sip.NewParams()} + fromHeader.Params.Add("tag", sip.GenerateTagN(16)) + + req := sip.NewRequest(method, to) + req.AppendHeader(fromHeader) + req.AppendHeader(&sip.ToHeader{Address: to}) + req.AppendHeader(&sip.ContactHeader{Address: *from}) + req.AppendHeader(sip.NewHeader("Max-Forwards", "70")) + req.SetBody(body) + req.SetDestination(dest) + req.SetTransport(conf.Transport) + + return req, nil +} + +func NewRegisterRequest(conf OutboundConfig) (*sip.Request, error) { + req, err := NewRequest(sip.REGISTER, nil, conf) + if err != nil { + return nil, err + } + req.AppendHeader(sip.NewHeader("Expires", "3600")) + + return req, nil +} + +func NewInviteRequest(body []byte, subject string, conf OutboundConfig) (*sip.Request, error) { + req, err := NewRequest(sip.INVITE, body, conf) + if err != nil { + return nil, err + } + req.AppendHeader(sip.NewHeader("Content-Type", "application/sdp")) + req.AppendHeader(sip.NewHeader("Subject", subject)) + + return req, nil +} + +func NewMessageRequest(body []byte, conf OutboundConfig) (*sip.Request, error) { + req, err := NewRequest(sip.MESSAGE, body, conf) + if err != nil { + return nil, err + } + req.AppendHeader(sip.NewHeader("Content-Type", "Application/MANSCDP+xml")) + + return req, nil +} diff --git a/pkg/service/stack/response.go b/pkg/service/stack/response.go index c3d4149..3009e09 100644 --- a/pkg/service/stack/response.go +++ b/pkg/service/stack/response.go @@ -1,41 +1,41 @@ -package stack - -import ( - "fmt" - "time" - - "github.com/emiago/sipgo/sip" -) - -const TIME_LAYOUT = "2024-01-01T00:00:00" -const EXPIRES_TIME = 3600 - -func newResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response { - resp := sip.NewResponseFromRequest(req, code, reason, nil) - - newTo := &sip.ToHeader{Address: resp.To().Address, Params: sip.NewParams()} - newTo.Params.Add("tag", sip.GenerateTagN(10)) - - resp.ReplaceHeader(newTo) - resp.RemoveHeader("Allow") - - return resp -} - -func NewRegisterResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response { - resp := newResponse(req, code, reason) - - expires := sip.ExpiresHeader(EXPIRES_TIME) - resp.AppendHeader(&expires) - resp.AppendHeader(sip.NewHeader("Date", time.Now().Format(TIME_LAYOUT))) - - return resp -} - -func NewUnauthorizedResponse(req *sip.Request, code sip.StatusCode, reason, nonce, realm string) *sip.Response { - resp := newResponse(req, code, reason) - - resp.AppendHeader(sip.NewHeader("WWW-Authenticate", fmt.Sprintf(`Digest realm="%s",nonce="%s",algorithm=MD5`, realm, nonce))) - - return resp -} +package stack + +import ( + "fmt" + "time" + + "github.com/emiago/sipgo/sip" +) + +const TIME_LAYOUT = "2024-01-01T00:00:00" +const EXPIRES_TIME = 3600 + +func newResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response { + resp := sip.NewResponseFromRequest(req, code, reason, nil) + + newTo := &sip.ToHeader{Address: resp.To().Address, Params: sip.NewParams()} + newTo.Params.Add("tag", sip.GenerateTagN(10)) + + resp.ReplaceHeader(newTo) + resp.RemoveHeader("Allow") + + return resp +} + +func NewRegisterResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response { + resp := newResponse(req, code, reason) + + expires := sip.ExpiresHeader(EXPIRES_TIME) + resp.AppendHeader(&expires) + resp.AppendHeader(sip.NewHeader("Date", time.Now().Format(TIME_LAYOUT))) + + return resp +} + +func NewUnauthorizedResponse(req *sip.Request, code sip.StatusCode, reason, nonce, realm string) *sip.Response { + resp := newResponse(req, code, reason) + + resp.AppendHeader(sip.NewHeader("WWW-Authenticate", fmt.Sprintf(`Digest realm="%s",nonce="%s",algorithm=MD5`, realm, nonce))) + + return resp +} diff --git a/pkg/service/uac.go b/pkg/service/uac.go index 066f835..dfd773c 100644 --- a/pkg/service/uac.go +++ b/pkg/service/uac.go @@ -1,122 +1,122 @@ -package service - -import ( - "context" - "fmt" - "log/slog" - - "github.com/emiago/sipgo" - "github.com/emiago/sipgo/sip" - "github.com/ossrs/go-oryx-lib/errors" - "github.com/ossrs/srs-sip/pkg/config" - "github.com/ossrs/srs-sip/pkg/service/stack" -) - -const ( - UserAgent = "SRS-SIP/1.0" -) - -type UAC struct { - *Cascade - - SN uint32 - LocalIP string -} - -func NewUac() *UAC { - ip, err := config.GetLocalIP() - if err != nil { - slog.Error("get local ip failed", "error", err) - return nil - } - - c := &UAC{ - Cascade: &Cascade{}, - LocalIP: ip, - } - return c -} - -func (c *UAC) Start(agent *sipgo.UserAgent, r0 interface{}) error { - var err error - - c.ctx = context.Background() - c.conf = r0.(*config.MainConfig) - - if agent == nil { - ua, err := sipgo.NewUA(sipgo.WithUserAgent(UserAgent)) - if err != nil { - return err - } - agent = ua - } - - c.sipCli, err = sipgo.NewClient(agent, sipgo.WithClientHostname(c.LocalIP)) - if err != nil { - return err - } - - c.sipSvr, err = sipgo.NewServer(agent) - if err != nil { - return err - } - - c.sipSvr.OnInvite(c.onInvite) - c.sipSvr.OnBye(c.onBye) - c.sipSvr.OnMessage(c.onMessage) - - go c.doRegister() - - return nil -} - -func (c *UAC) Stop() { - // TODO: 断开所有当前连接 - c.sipCli.Close() - c.sipSvr.Close() -} - -func (c *UAC) doRegister() error { - r, _ := stack.NewRegisterRequest(stack.OutboundConfig{ - From: "34020000001110000001", - To: "34020000002000000001", - Transport: "UDP", - Via: fmt.Sprintf("%s:%d", c.LocalIP, c.conf.GB28181.Port), - }) - tx, err := c.sipCli.TransactionRequest(c.ctx, r) - if err != nil { - return errors.Wrapf(err, "transaction request error") - } - - rs, _ := c.getResponse(tx) - slog.Info("register response", "response", rs.String()) - return nil -} - -func (c *UAC) OnRequest(req *sip.Request, tx sip.ServerTransaction) { - switch req.Method { - case "INVITE": - c.onInvite(req, tx) - } -} - -func (c *UAC) onInvite(req *sip.Request, tx sip.ServerTransaction) { - slog.Debug("onInvite") -} - -func (c *UAC) onBye(req *sip.Request, tx sip.ServerTransaction) { - slog.Debug("onBye") -} - -func (c *UAC) onMessage(req *sip.Request, tx sip.ServerTransaction) { - slog.Debug("onMessage", "request", req.String()) -} - -func (c *UAC) getResponse(tx sip.ClientTransaction) (*sip.Response, error) { - select { - case <-tx.Done(): - return nil, fmt.Errorf("transaction died") - case res := <-tx.Responses(): - return res, nil - } -} +package service + +import ( + "context" + "fmt" + "log/slog" + + "github.com/emiago/sipgo" + "github.com/emiago/sipgo/sip" + "github.com/ossrs/go-oryx-lib/errors" + "github.com/ossrs/srs-sip/pkg/config" + "github.com/ossrs/srs-sip/pkg/service/stack" +) + +const ( + UserAgent = "SRS-SIP/1.0" +) + +type UAC struct { + *Cascade + + SN uint32 + LocalIP string +} + +func NewUac() *UAC { + ip, err := config.GetLocalIP() + if err != nil { + slog.Error("get local ip failed", "error", err) + return nil + } + + c := &UAC{ + Cascade: &Cascade{}, + LocalIP: ip, + } + return c +} + +func (c *UAC) Start(agent *sipgo.UserAgent, r0 interface{}) error { + var err error + + c.ctx = context.Background() + c.conf = r0.(*config.MainConfig) + + if agent == nil { + ua, err := sipgo.NewUA(sipgo.WithUserAgent(UserAgent)) + if err != nil { + return err + } + agent = ua + } + + c.sipCli, err = sipgo.NewClient(agent, sipgo.WithClientHostname(c.LocalIP)) + if err != nil { + return err + } + + c.sipSvr, err = sipgo.NewServer(agent) + if err != nil { + return err + } + + c.sipSvr.OnInvite(c.onInvite) + c.sipSvr.OnBye(c.onBye) + c.sipSvr.OnMessage(c.onMessage) + + go c.doRegister() + + return nil +} + +func (c *UAC) Stop() { + // TODO: 断开所有当前连接 + c.sipCli.Close() + c.sipSvr.Close() +} + +func (c *UAC) doRegister() error { + r, _ := stack.NewRegisterRequest(stack.OutboundConfig{ + From: "34020000001110000001", + To: "34020000002000000001", + Transport: "UDP", + Via: fmt.Sprintf("%s:%d", c.LocalIP, c.conf.GB28181.Port), + }) + tx, err := c.sipCli.TransactionRequest(c.ctx, r) + if err != nil { + return errors.Wrapf(err, "transaction request error") + } + + rs, _ := c.getResponse(tx) + slog.Info("register response", "response", rs.String()) + return nil +} + +func (c *UAC) OnRequest(req *sip.Request, tx sip.ServerTransaction) { + switch req.Method { + case "INVITE": + c.onInvite(req, tx) + } +} + +func (c *UAC) onInvite(req *sip.Request, tx sip.ServerTransaction) { + slog.Debug("onInvite") +} + +func (c *UAC) onBye(req *sip.Request, tx sip.ServerTransaction) { + slog.Debug("onBye") +} + +func (c *UAC) onMessage(req *sip.Request, tx sip.ServerTransaction) { + slog.Debug("onMessage", "request", req.String()) +} + +func (c *UAC) getResponse(tx sip.ClientTransaction) (*sip.Response, error) { + select { + case <-tx.Done(): + return nil, fmt.Errorf("transaction died") + case res := <-tx.Responses(): + return res, nil + } +} diff --git a/pkg/utils/logger.go b/pkg/utils/logger.go index 3f9e5b1..f01df8d 100644 --- a/pkg/utils/logger.go +++ b/pkg/utils/logger.go @@ -1,201 +1,201 @@ -package utils - -import ( - "context" - "fmt" - "io" - "log/slog" - "os" - "path/filepath" - "strings" - "sync" -) - -var logLevelMap = map[string]slog.Level{ - "debug": slog.LevelDebug, - "info": slog.LevelInfo, - "warn": slog.LevelWarn, - "error": slog.LevelError, -} - -// 自定义格式处理器,以 [时间] [级别] [消息] 格式输出日志 -type CustomFormatHandler struct { - mu sync.Mutex - w io.Writer - level slog.Level - attrs []slog.Attr - groups []string -} - -// NewCustomFormatHandler 创建一个新的自定义格式处理器 -func NewCustomFormatHandler(w io.Writer, opts *slog.HandlerOptions) *CustomFormatHandler { - if opts == nil { - opts = &slog.HandlerOptions{} - } - - // 获取日志级别,如果opts.Level是nil则默认为Info - var level slog.Level - if opts.Level != nil { - level = opts.Level.Level() - } else { - level = slog.LevelInfo - } - - return &CustomFormatHandler{ - w: w, - level: level, - } -} - -// Enabled 实现 slog.Handler 接口 -func (h *CustomFormatHandler) Enabled(ctx context.Context, level slog.Level) bool { - return level >= h.level -} - -// Handle 实现 slog.Handler 接口,以自定义格式输出日志 -func (h *CustomFormatHandler) Handle(ctx context.Context, record slog.Record) error { - h.mu.Lock() - defer h.mu.Unlock() - - // 时间格式 - timeStr := record.Time.Format("2006-01-02 15:04:05.000") - - // 日志级别 - var levelStr string - switch { - case record.Level >= slog.LevelError: - levelStr = "ERROR" - case record.Level >= slog.LevelWarn: - levelStr = "WARN " - case record.Level >= slog.LevelInfo: - levelStr = "INFO " - default: - levelStr = "DEBUG" - } - - // 构建日志行 - logLine := fmt.Sprintf("[%s] [%s] %s", timeStr, levelStr, record.Message) - - // 处理其他属性 - var attrs []string - record.Attrs(func(attr slog.Attr) bool { - attrs = append(attrs, fmt.Sprintf("%s=%v", attr.Key, attr.Value)) - return true - }) - - if len(attrs) > 0 { - logLine += " " + strings.Join(attrs, " ") - } - - // 写入日志 - _, err := fmt.Fprintln(h.w, logLine) - return err -} - -// WithAttrs 实现 slog.Handler 接口 -func (h *CustomFormatHandler) WithAttrs(attrs []slog.Attr) slog.Handler { - h2 := *h - h2.attrs = append(h.attrs[:], attrs...) - return &h2 -} - -// WithGroup 实现 slog.Handler 接口 -func (h *CustomFormatHandler) WithGroup(name string) slog.Handler { - h2 := *h - h2.groups = append(h.groups[:], name) - return &h2 -} - -// MultiHandler 实现了 slog.Handler 接口,将日志同时发送到多个处理器 -type MultiHandler struct { - handlers []slog.Handler -} - -// Enabled 实现 slog.Handler 接口 -func (h *MultiHandler) Enabled(ctx context.Context, level slog.Level) bool { - // 如果任何一个处理器启用了该级别,则返回 true - for _, handler := range h.handlers { - if handler.Enabled(ctx, level) { - return true - } - } - return false -} - -// Handle 实现 slog.Handler 接口 -func (h *MultiHandler) Handle(ctx context.Context, record slog.Record) error { - // 将记录发送到所有处理器 - for _, handler := range h.handlers { - if handler.Enabled(ctx, record.Level) { - if err := handler.Handle(ctx, record); err != nil { - return err - } - } - } - return nil -} - -// WithAttrs 实现 slog.Handler 接口 -func (h *MultiHandler) WithAttrs(attrs []slog.Attr) slog.Handler { - newHandlers := make([]slog.Handler, len(h.handlers)) - for i, handler := range h.handlers { - newHandlers[i] = handler.WithAttrs(attrs) - } - return &MultiHandler{handlers: newHandlers} -} - -// WithGroup 实现 slog.Handler 接口 -func (h *MultiHandler) WithGroup(name string) slog.Handler { - newHandlers := make([]slog.Handler, len(h.handlers)) - for i, handler := range h.handlers { - newHandlers[i] = handler.WithGroup(name) - } - return &MultiHandler{handlers: newHandlers} -} - -// SetupLogger 设置日志输出 -func SetupLogger(logLevel string, logFile string) error { - // 创建标准错误输出的处理器,使用自定义格式 - stdHandler := NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{ - Level: logLevelMap[logLevel], - }) - - // 如果没有指定日志文件,则仅使用标准错误处理器 - if logFile == "" { - slog.SetDefault(slog.New(stdHandler)) - return nil - } - - // 确保日志文件所在目录存在 - logDir := filepath.Dir(logFile) - if err := os.MkdirAll(logDir, 0755); err != nil { - return err - } - - // 打开日志文件,如果不存在则创建,追加写入模式 - file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) - if err != nil { - return err - } - - // 创建文件输出的处理器,使用自定义格式 - fileHandler := NewCustomFormatHandler(file, &slog.HandlerOptions{ - Level: logLevelMap[logLevel], - }) - - // 创建多输出处理器 - multiHandler := &MultiHandler{ - handlers: []slog.Handler{stdHandler, fileHandler}, - } - - // 设置全局日志处理器 - slog.SetDefault(slog.New(multiHandler)) - return nil -} - -// InitDefaultLogger 初始化默认日志处理器 -func InitDefaultLogger(level slog.Level) { - slog.SetDefault(slog.New(NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{ - Level: level, - }))) -} +package utils + +import ( + "context" + "fmt" + "io" + "log/slog" + "os" + "path/filepath" + "strings" + "sync" +) + +var logLevelMap = map[string]slog.Level{ + "debug": slog.LevelDebug, + "info": slog.LevelInfo, + "warn": slog.LevelWarn, + "error": slog.LevelError, +} + +// 自定义格式处理器,以 [时间] [级别] [消息] 格式输出日志 +type CustomFormatHandler struct { + mu sync.Mutex + w io.Writer + level slog.Level + attrs []slog.Attr + groups []string +} + +// NewCustomFormatHandler 创建一个新的自定义格式处理器 +func NewCustomFormatHandler(w io.Writer, opts *slog.HandlerOptions) *CustomFormatHandler { + if opts == nil { + opts = &slog.HandlerOptions{} + } + + // 获取日志级别,如果opts.Level是nil则默认为Info + var level slog.Level + if opts.Level != nil { + level = opts.Level.Level() + } else { + level = slog.LevelInfo + } + + return &CustomFormatHandler{ + w: w, + level: level, + } +} + +// Enabled 实现 slog.Handler 接口 +func (h *CustomFormatHandler) Enabled(ctx context.Context, level slog.Level) bool { + return level >= h.level +} + +// Handle 实现 slog.Handler 接口,以自定义格式输出日志 +func (h *CustomFormatHandler) Handle(ctx context.Context, record slog.Record) error { + h.mu.Lock() + defer h.mu.Unlock() + + // 时间格式 + timeStr := record.Time.Format("2006-01-02 15:04:05.000") + + // 日志级别 + var levelStr string + switch { + case record.Level >= slog.LevelError: + levelStr = "ERROR" + case record.Level >= slog.LevelWarn: + levelStr = "WARN " + case record.Level >= slog.LevelInfo: + levelStr = "INFO " + default: + levelStr = "DEBUG" + } + + // 构建日志行 + logLine := fmt.Sprintf("[%s] [%s] %s", timeStr, levelStr, record.Message) + + // 处理其他属性 + var attrs []string + record.Attrs(func(attr slog.Attr) bool { + attrs = append(attrs, fmt.Sprintf("%s=%v", attr.Key, attr.Value)) + return true + }) + + if len(attrs) > 0 { + logLine += " " + strings.Join(attrs, " ") + } + + // 写入日志 + _, err := fmt.Fprintln(h.w, logLine) + return err +} + +// WithAttrs 实现 slog.Handler 接口 +func (h *CustomFormatHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + h2 := *h + h2.attrs = append(h.attrs[:], attrs...) + return &h2 +} + +// WithGroup 实现 slog.Handler 接口 +func (h *CustomFormatHandler) WithGroup(name string) slog.Handler { + h2 := *h + h2.groups = append(h.groups[:], name) + return &h2 +} + +// MultiHandler 实现了 slog.Handler 接口,将日志同时发送到多个处理器 +type MultiHandler struct { + handlers []slog.Handler +} + +// Enabled 实现 slog.Handler 接口 +func (h *MultiHandler) Enabled(ctx context.Context, level slog.Level) bool { + // 如果任何一个处理器启用了该级别,则返回 true + for _, handler := range h.handlers { + if handler.Enabled(ctx, level) { + return true + } + } + return false +} + +// Handle 实现 slog.Handler 接口 +func (h *MultiHandler) Handle(ctx context.Context, record slog.Record) error { + // 将记录发送到所有处理器 + for _, handler := range h.handlers { + if handler.Enabled(ctx, record.Level) { + if err := handler.Handle(ctx, record); err != nil { + return err + } + } + } + return nil +} + +// WithAttrs 实现 slog.Handler 接口 +func (h *MultiHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + newHandlers := make([]slog.Handler, len(h.handlers)) + for i, handler := range h.handlers { + newHandlers[i] = handler.WithAttrs(attrs) + } + return &MultiHandler{handlers: newHandlers} +} + +// WithGroup 实现 slog.Handler 接口 +func (h *MultiHandler) WithGroup(name string) slog.Handler { + newHandlers := make([]slog.Handler, len(h.handlers)) + for i, handler := range h.handlers { + newHandlers[i] = handler.WithGroup(name) + } + return &MultiHandler{handlers: newHandlers} +} + +// SetupLogger 设置日志输出 +func SetupLogger(logLevel string, logFile string) error { + // 创建标准错误输出的处理器,使用自定义格式 + stdHandler := NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{ + Level: logLevelMap[logLevel], + }) + + // 如果没有指定日志文件,则仅使用标准错误处理器 + if logFile == "" { + slog.SetDefault(slog.New(stdHandler)) + return nil + } + + // 确保日志文件所在目录存在 + logDir := filepath.Dir(logFile) + if err := os.MkdirAll(logDir, 0755); err != nil { + return err + } + + // 打开日志文件,如果不存在则创建,追加写入模式 + file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + return err + } + + // 创建文件输出的处理器,使用自定义格式 + fileHandler := NewCustomFormatHandler(file, &slog.HandlerOptions{ + Level: logLevelMap[logLevel], + }) + + // 创建多输出处理器 + multiHandler := &MultiHandler{ + handlers: []slog.Handler{stdHandler, fileHandler}, + } + + // 设置全局日志处理器 + slog.SetDefault(slog.New(multiHandler)) + return nil +} + +// InitDefaultLogger 初始化默认日志处理器 +func InitDefaultLogger(level slog.Level) { + slog.SetDefault(slog.New(NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{ + Level: level, + }))) +} diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go index 87bb9f9..62f273b 100644 --- a/pkg/utils/utils_test.go +++ b/pkg/utils/utils_test.go @@ -167,4 +167,3 @@ func TestGetSessionName(t *testing.T) { }) } } - diff --git a/tools/main.go b/tools/main.go index fd9c08c..6721a20 100644 --- a/tools/main.go +++ b/tools/main.go @@ -1,30 +1,30 @@ -package main - -import ( - "context" - "os" - "os/signal" - "syscall" - - "github.com/ossrs/go-oryx-lib/logger" - "github.com/ossrs/srs-bench/gb28181" -) - -func main() { - ctx := context.Background() - - var conf interface{} - conf = gb28181.Parse(ctx) - - ctx, cancel := context.WithCancel(ctx) - go func() { - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT) - for sig := range sigs { - logger.Wf(ctx, "Quit for signal %v", sig) - cancel() - } - }() - - gb28181.Run(ctx, conf) -} +package main + +import ( + "context" + "os" + "os/signal" + "syscall" + + "github.com/ossrs/go-oryx-lib/logger" + "github.com/ossrs/srs-bench/gb28181" +) + +func main() { + ctx := context.Background() + + var conf interface{} + conf = gb28181.Parse(ctx) + + ctx, cancel := context.WithCancel(ctx) + go func() { + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT) + for sig := range sigs { + logger.Wf(ctx, "Quit for signal %v", sig) + cancel() + } + }() + + gb28181.Run(ctx, conf) +}