This commit is contained in:
haibo.chen
2025-10-15 10:05:52 +08:00
parent d9709f61a5
commit 156f07644d
18 changed files with 1222 additions and 1227 deletions

View File

@ -133,7 +133,7 @@ func TestLoadConfigInvalid(t *testing.T) {
func TestGetLocalIP(t *testing.T) { func TestGetLocalIP(t *testing.T) {
ip, err := GetLocalIP() ip, err := GetLocalIP()
// 在某些环境下可能没有网络接口,所以允许返回错误 // 在某些环境下可能没有网络接口,所以允许返回错误
if err != nil { if err != nil {
t.Logf("GetLocalIP returned error (may be expected in some environments): %v", err) t.Logf("GetLocalIP returned error (may be expected in some environments): %v", err)
@ -152,4 +152,3 @@ func TestGetLocalIP(t *testing.T) {
t.Logf("Local IP: %s", ip) t.Logf("Local IP: %s", ip)
} }

View File

@ -1,152 +1,152 @@
package db package db
import ( import (
"database/sql" "database/sql"
"sync" "sync"
"github.com/ossrs/srs-sip/pkg/models" "github.com/ossrs/srs-sip/pkg/models"
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
) )
var ( var (
instance *MediaServerDB instance *MediaServerDB
once sync.Once once sync.Once
) )
type MediaServerDB struct { type MediaServerDB struct {
models.MediaServerResponse models.MediaServerResponse
db *sql.DB db *sql.DB
} }
// GetInstance 返回 MediaServerDB 的单例实例 // GetInstance 返回 MediaServerDB 的单例实例
func GetInstance(dbPath string) (*MediaServerDB, error) { func GetInstance(dbPath string) (*MediaServerDB, error) {
var err error var err error
once.Do(func() { once.Do(func() {
instance, err = NewMediaServerDB(dbPath) instance, err = NewMediaServerDB(dbPath)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
return instance, nil return instance, nil
} }
func NewMediaServerDB(dbPath string) (*MediaServerDB, error) { func NewMediaServerDB(dbPath string) (*MediaServerDB, error) {
db, err := sql.Open("sqlite", dbPath) db, err := sql.Open("sqlite", dbPath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 创建媒体服务器表 // 创建媒体服务器表
_, err = db.Exec(` _, err = db.Exec(`
CREATE TABLE IF NOT EXISTS media_servers ( CREATE TABLE IF NOT EXISTS media_servers (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
type TEXT NOT NULL, type TEXT NOT NULL,
name TEXT NOT NULL, name TEXT NOT NULL,
ip TEXT NOT NULL, ip TEXT NOT NULL,
port INTEGER NOT NULL, port INTEGER NOT NULL,
username TEXT, username TEXT,
password TEXT, password TEXT,
secret TEXT, secret TEXT,
is_default INTEGER NOT NULL DEFAULT 0, is_default INTEGER NOT NULL DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP created_at DATETIME DEFAULT CURRENT_TIMESTAMP
) )
`) `)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &MediaServerDB{db: db}, nil return &MediaServerDB{db: db}, nil
} }
// GetMediaServerByNameAndIP 根据名称和IP查询媒体服务器 // GetMediaServerByNameAndIP 根据名称和IP查询媒体服务器
func (m *MediaServerDB) GetMediaServerByNameAndIP(name, ip string) (*models.MediaServerResponse, error) { func (m *MediaServerDB) GetMediaServerByNameAndIP(name, ip string) (*models.MediaServerResponse, error) {
var ms models.MediaServerResponse var ms models.MediaServerResponse
err := m.db.QueryRow(` err := m.db.QueryRow(`
SELECT id, name, type, ip, port, username, password, secret, is_default, created_at SELECT id, name, type, ip, port, username, password, secret, is_default, created_at
FROM media_servers WHERE name = ? AND ip = ? FROM media_servers WHERE name = ? AND ip = ?
`, name, ip).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt) `, name, ip).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &ms, nil return &ms, nil
} }
func (m *MediaServerDB) AddMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error { func (m *MediaServerDB) AddMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error {
_, err := m.db.Exec(` _, err := m.db.Exec(`
INSERT INTO media_servers (name, type, ip, port, username, password, secret, is_default) INSERT INTO media_servers (name, type, ip, port, username, password, secret, is_default)
VALUES (?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`, name, serverType, ip, port, username, password, secret, isDefault) `, name, serverType, ip, port, username, password, secret, isDefault)
return err return err
} }
// AddOrUpdateMediaServer 添加或更新媒体服务器(如果已存在则更新) // AddOrUpdateMediaServer 添加或更新媒体服务器(如果已存在则更新)
func (m *MediaServerDB) AddOrUpdateMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error { func (m *MediaServerDB) AddOrUpdateMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error {
// 检查是否已存在 // 检查是否已存在
existing, err := m.GetMediaServerByNameAndIP(name, ip) existing, err := m.GetMediaServerByNameAndIP(name, ip)
if err == nil && existing != nil { if err == nil && existing != nil {
// 已存在,更新记录 // 已存在,更新记录
_, err = m.db.Exec(` _, err = m.db.Exec(`
UPDATE media_servers UPDATE media_servers
SET type = ?, port = ?, username = ?, password = ?, secret = ?, is_default = ? SET type = ?, port = ?, username = ?, password = ?, secret = ?, is_default = ?
WHERE name = ? AND ip = ? WHERE name = ? AND ip = ?
`, serverType, port, username, password, secret, isDefault, name, ip) `, serverType, port, username, password, secret, isDefault, name, ip)
return err return err
} }
// 不存在,插入新记录 // 不存在,插入新记录
return m.AddMediaServer(name, serverType, ip, port, username, password, secret, isDefault) return m.AddMediaServer(name, serverType, ip, port, username, password, secret, isDefault)
} }
func (m *MediaServerDB) DeleteMediaServer(id int) error { func (m *MediaServerDB) DeleteMediaServer(id int) error {
_, err := m.db.Exec("DELETE FROM media_servers WHERE id = ?", id) _, err := m.db.Exec("DELETE FROM media_servers WHERE id = ?", id)
return err return err
} }
func (m *MediaServerDB) GetMediaServer(id int) (*models.MediaServerResponse, error) { func (m *MediaServerDB) GetMediaServer(id int) (*models.MediaServerResponse, error) {
var ms models.MediaServerResponse var ms models.MediaServerResponse
err := m.db.QueryRow(` err := m.db.QueryRow(`
SELECT id, name, type, ip, port, username, password, secret, is_default, created_at SELECT id, name, type, ip, port, username, password, secret, is_default, created_at
FROM media_servers WHERE id = ? FROM media_servers WHERE id = ?
`, id).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt) `, id).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &ms, nil return &ms, nil
} }
func (m *MediaServerDB) ListMediaServers() ([]models.MediaServerResponse, error) { func (m *MediaServerDB) ListMediaServers() ([]models.MediaServerResponse, error) {
rows, err := m.db.Query(` rows, err := m.db.Query(`
SELECT id, name, type, ip, port, username, password, secret, is_default, created_at SELECT id, name, type, ip, port, username, password, secret, is_default, created_at
FROM media_servers ORDER BY created_at DESC FROM media_servers ORDER BY created_at DESC
`) `)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
var servers []models.MediaServerResponse var servers []models.MediaServerResponse
for rows.Next() { for rows.Next() {
var ms models.MediaServerResponse var ms models.MediaServerResponse
err := rows.Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt) err := rows.Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
servers = append(servers, ms) servers = append(servers, ms)
} }
return servers, nil return servers, nil
} }
func (m *MediaServerDB) SetDefaultMediaServer(id int) error { func (m *MediaServerDB) SetDefaultMediaServer(id int) error {
// 先将所有服务器设置为非默认 // 先将所有服务器设置为非默认
if _, err := m.db.Exec("UPDATE media_servers SET is_default = 0"); err != nil { if _, err := m.db.Exec("UPDATE media_servers SET is_default = 0"); err != nil {
return err return err
} }
// 将指定ID的服务器设置为默认 // 将指定ID的服务器设置为默认
_, err := m.db.Exec("UPDATE media_servers SET is_default = 1 WHERE id = ?", id) _, err := m.db.Exec("UPDATE media_servers SET is_default = 1 WHERE id = ?", id)
return err return err
} }
func (m *MediaServerDB) Close() error { func (m *MediaServerDB) Close() error {
return m.db.Close() return m.db.Close()
} }

View File

@ -1,59 +1,59 @@
package media package media
import ( import (
"context" "context"
"github.com/ossrs/go-oryx-lib/errors" "github.com/ossrs/go-oryx-lib/errors"
) )
type Zlm struct { type Zlm struct {
Ctx context.Context Ctx context.Context
Schema string // The schema of ZLM, eg: http Schema string // The schema of ZLM, eg: http
Addr string // The address of ZLM, eg: localhost:8085 Addr string // The address of ZLM, eg: localhost:8085
Secret string // The secret of ZLM, eg: ZLMediaKit_secret Secret string // The secret of ZLM, eg: ZLMediaKit_secret
} }
// /index/api/openRtpServer // /index/api/openRtpServer
// secret={{ZLMediaKit_secret}}&port=0&enable_tcp=1&stream_id=test2 // secret={{ZLMediaKit_secret}}&port=0&enable_tcp=1&stream_id=test2
func (z *Zlm) Publish(id, ssrc string) (int, error) { func (z *Zlm) Publish(id, ssrc string) (int, error) {
res := struct { res := struct {
Code int `json:"code"` Code int `json:"code"`
Port int `json:"port"` Port int `json:"port"`
}{} }{}
if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/openRtpServer?secret="+z.Secret+"&port=0&enable_tcp=1&stream_id="+id+"&ssrc="+ssrc, nil, &res); err != nil { if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/openRtpServer?secret="+z.Secret+"&port=0&enable_tcp=1&stream_id="+id+"&ssrc="+ssrc, nil, &res); err != nil {
return 0, errors.Wrapf(err, "gb/v1/publish") return 0, errors.Wrapf(err, "gb/v1/publish")
} }
return res.Port, nil return res.Port, nil
} }
// /index/api/closeRtpServer // /index/api/closeRtpServer
func (z *Zlm) Unpublish(id string) error { func (z *Zlm) Unpublish(id string) error {
res := struct { res := struct {
Code int `json:"code"` Code int `json:"code"`
}{} }{}
if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/closeRtpServer?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil { if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/closeRtpServer?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil {
return errors.Wrapf(err, "gb/v1/publish") return errors.Wrapf(err, "gb/v1/publish")
} }
return nil return nil
} }
// /index/api/getMediaList // /index/api/getMediaList
func (z *Zlm) GetStreamStatus(id string) (bool, error) { func (z *Zlm) GetStreamStatus(id string) (bool, error) {
res := struct { res := struct {
Code int `json:"code"` Code int `json:"code"`
}{} }{}
if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/getMediaList?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil { if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/getMediaList?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil {
return false, errors.Wrapf(err, "gb/v1/publish") return false, errors.Wrapf(err, "gb/v1/publish")
} }
return res.Code == 0, nil return res.Code == 0, nil
} }
func (z *Zlm) GetAddr() string { func (z *Zlm) GetAddr() string {
return z.Addr return z.Addr
} }
func (z *Zlm) GetWebRTCAddr(id string) string { func (z *Zlm) GetWebRTCAddr(id string) string {
return "http://" + z.Addr + "/index/api/webrtc?app=rtp&stream=" + id + "&type=play" return "http://" + z.Addr + "/index/api/webrtc?app=rtp&stream=" + id + "&type=play"
} }

View File

@ -1,106 +1,106 @@
package models package models
import "encoding/xml" import "encoding/xml"
type Record struct { type Record struct {
DeviceID string `xml:"DeviceID" json:"device_id"` DeviceID string `xml:"DeviceID" json:"device_id"`
Name string `xml:"Name" json:"name"` Name string `xml:"Name" json:"name"`
FilePath string `xml:"FilePath" json:"file_path"` FilePath string `xml:"FilePath" json:"file_path"`
Address string `xml:"Address" json:"address"` Address string `xml:"Address" json:"address"`
StartTime string `xml:"StartTime" json:"start_time"` StartTime string `xml:"StartTime" json:"start_time"`
EndTime string `xml:"EndTime" json:"end_time"` EndTime string `xml:"EndTime" json:"end_time"`
Secrecy int `xml:"Secrecy" json:"secrecy"` Secrecy int `xml:"Secrecy" json:"secrecy"`
Type string `xml:"Type" json:"type"` Type string `xml:"Type" json:"type"`
} }
// Example XML structure for channel info: // Example XML structure for channel info:
// //
// <Item> // <Item>
// <DeviceID>34020000001320000002</DeviceID> // <DeviceID>34020000001320000002</DeviceID>
// <Name>209</Name> // <Name>209</Name>
// <Manufacturer>UNIVIEW</Manufacturer> // <Manufacturer>UNIVIEW</Manufacturer>
// <Model>HIC6622-IR@X33-VF</Model> // <Model>HIC6622-IR@X33-VF</Model>
// <Owner>IPC-B2202.7.11.230222</Owner> // <Owner>IPC-B2202.7.11.230222</Owner>
// <CivilCode>CivilCode</CivilCode> // <CivilCode>CivilCode</CivilCode>
// <Address>Address</Address> // <Address>Address</Address>
// <Parental>1</Parental> // <Parental>1</Parental>
// <ParentID>75015310072008100002</ParentID> // <ParentID>75015310072008100002</ParentID>
// <SafetyWay>0</SafetyWay> // <SafetyWay>0</SafetyWay>
// <RegisterWay>1</RegisterWay> // <RegisterWay>1</RegisterWay>
// <Secrecy>0</Secrecy> // <Secrecy>0</Secrecy>
// <Status>ON</Status> // <Status>ON</Status>
// <Longitude>0.0000000</Longitude> // <Longitude>0.0000000</Longitude>
// <Latitude>0.0000000</Latitude> // <Latitude>0.0000000</Latitude>
// <Info> // <Info>
// <PTZType>1</PTZType> // <PTZType>1</PTZType>
// <Resolution>6/4/2</Resolution> // <Resolution>6/4/2</Resolution>
// <DownloadSpeed>0</DownloadSpeed> // <DownloadSpeed>0</DownloadSpeed>
// </Info> // </Info>
// </Item> // </Item>
type ChannelInfo struct { type ChannelInfo struct {
DeviceID string `json:"device_id"` DeviceID string `json:"device_id"`
ParentID string `json:"parent_id"` ParentID string `json:"parent_id"`
Name string `json:"name"` Name string `json:"name"`
Manufacturer string `json:"manufacturer"` Manufacturer string `json:"manufacturer"`
Model string `json:"model"` Model string `json:"model"`
Owner string `json:"owner"` Owner string `json:"owner"`
CivilCode string `json:"civil_code"` CivilCode string `json:"civil_code"`
Address string `json:"address"` Address string `json:"address"`
Port int `json:"port"` Port int `json:"port"`
Parental int `json:"parental"` Parental int `json:"parental"`
SafetyWay int `json:"safety_way"` SafetyWay int `json:"safety_way"`
RegisterWay int `json:"register_way"` RegisterWay int `json:"register_way"`
Secrecy int `json:"secrecy"` Secrecy int `json:"secrecy"`
IPAddress string `json:"ip_address"` IPAddress string `json:"ip_address"`
Status ChannelStatus `json:"status"` Status ChannelStatus `json:"status"`
Longitude float64 `json:"longitude"` Longitude float64 `json:"longitude"`
Latitude float64 `json:"latitude"` Latitude float64 `json:"latitude"`
Info struct { Info struct {
PTZType int `json:"ptz_type"` PTZType int `json:"ptz_type"`
Resolution string `json:"resolution"` Resolution string `json:"resolution"`
DownloadSpeed string `json:"download_speed"` // Speed levels: 1/2/4/8 DownloadSpeed string `json:"download_speed"` // Speed levels: 1/2/4/8
} `json:"info"` } `json:"info"`
// Custom fields // Custom fields
Ssrc string `json:"ssrc"` Ssrc string `json:"ssrc"`
} }
type ChannelStatus string type ChannelStatus string
// BasicParam // BasicParam
// <! -- 基本参数配置(可选)--> // <! -- 基本参数配置(可选)-->
// <elementname="BasicParam"minOccurs="0"> // <elementname="BasicParam"minOccurs="0">
// <complexType> // <complexType>
// <sequence> // <sequence>
// <! -- 设备名称(可选)--> // <! -- 设备名称(可选)-->
// <elementname="Name"type="string" minOccurs="0"/> // <elementname="Name"type="string" minOccurs="0"/>
// <! -- 注册过期时间(可选)--> // <! -- 注册过期时间(可选)-->
// <elementname="Expiration"type="integer" minOccurs="0"/> // <elementname="Expiration"type="integer" minOccurs="0"/>
// <! -- 心跳间隔时间(可选)--> // <! -- 心跳间隔时间(可选)-->
// <elementname="HeartBeatInterval"type="integer" minOccurs="0"/> // <elementname="HeartBeatInterval"type="integer" minOccurs="0"/>
// <! -- 心跳超时次数(可选)--> // <! -- 心跳超时次数(可选)-->
// <elementname="HeartBeatCount"type="integer" minOccurs="0"/> // <elementname="HeartBeatCount"type="integer" minOccurs="0"/>
// </sequence> // </sequence>
// </complexType> // </complexType>
type BasicParam struct { type BasicParam struct {
Name string `xml:"Name"` Name string `xml:"Name"`
Expiration int `xml:"Expiration"` Expiration int `xml:"Expiration"`
HeartBeatInterval int `xml:"HeartBeatInterval"` HeartBeatInterval int `xml:"HeartBeatInterval"`
HeartBeatCount int `xml:"HeartBeatCount"` HeartBeatCount int `xml:"HeartBeatCount"`
} }
type XmlMessageInfo struct { type XmlMessageInfo struct {
XMLName xml.Name XMLName xml.Name
CmdType string CmdType string
SN int SN int
DeviceID string DeviceID string
DeviceName string DeviceName string
Manufacturer string Manufacturer string
Model string Model string
Channel string Channel string
DeviceList []ChannelInfo `xml:"DeviceList>Item"` DeviceList []ChannelInfo `xml:"DeviceList>Item"`
RecordList []*Record `xml:"RecordList>Item"` RecordList []*Record `xml:"RecordList>Item"`
BasicParam BasicParam `xml:"BasicParam"` BasicParam BasicParam `xml:"BasicParam"`
SumNum int SumNum int
} }

View File

@ -1,80 +1,80 @@
package models package models
type BaseRequest struct { type BaseRequest struct {
DeviceID string `json:"device_id"` DeviceID string `json:"device_id"`
ChannelID string `json:"channel_id"` ChannelID string `json:"channel_id"`
} }
type InviteRequest struct { type InviteRequest struct {
BaseRequest BaseRequest
MediaServerId int `json:"media_server_id"` MediaServerId int `json:"media_server_id"`
PlayType int `json:"play_type"` // 0: live, 1: playback, 2: download PlayType int `json:"play_type"` // 0: live, 1: playback, 2: download
SubStream int `json:"sub_stream"` SubStream int `json:"sub_stream"`
StartTime int64 `json:"start_time"` StartTime int64 `json:"start_time"`
EndTime int64 `json:"end_time"` EndTime int64 `json:"end_time"`
} }
type InviteResponse struct { type InviteResponse struct {
ChannelID string `json:"channel_id"` ChannelID string `json:"channel_id"`
URL string `json:"url"` URL string `json:"url"`
} }
type SessionRequest struct { type SessionRequest struct {
BaseRequest BaseRequest
URL string `json:"url"` URL string `json:"url"`
} }
type ByeRequest struct { type ByeRequest struct {
SessionRequest SessionRequest
} }
type PauseRequest struct { type PauseRequest struct {
SessionRequest SessionRequest
} }
type ResumeRequest struct { type ResumeRequest struct {
SessionRequest SessionRequest
} }
type SpeedRequest struct { type SpeedRequest struct {
SessionRequest SessionRequest
Speed float32 `json:"speed"` Speed float32 `json:"speed"`
} }
type PTZControlRequest struct { type PTZControlRequest struct {
BaseRequest BaseRequest
PTZ string `json:"ptz"` PTZ string `json:"ptz"`
Speed string `json:"speed"` Speed string `json:"speed"`
} }
type QueryRecordRequest struct { type QueryRecordRequest struct {
BaseRequest BaseRequest
StartTime int64 `json:"start_time"` StartTime int64 `json:"start_time"`
EndTime int64 `json:"end_time"` EndTime int64 `json:"end_time"`
} }
type MediaServer struct { type MediaServer struct {
Name string `json:"name"` Name string `json:"name"`
Type string `json:"type"` Type string `json:"type"`
IP string `json:"ip"` IP string `json:"ip"`
Port int `json:"port"` Port int `json:"port"`
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
Secret string `json:"secret"` Secret string `json:"secret"`
IsDefault int `json:"is_default"` IsDefault int `json:"is_default"`
} }
type MediaServerRequest struct { type MediaServerRequest struct {
MediaServer MediaServer
} }
type MediaServerResponse struct { type MediaServerResponse struct {
MediaServer MediaServer
ID int `json:"id"` ID int `json:"id"`
CreatedAt string `json:"created_at"` CreatedAt string `json:"created_at"`
} }
type CommonResponse struct { type CommonResponse struct {
Code int `json:"code"` Code int `json:"code"`
Data interface{} `json:"data"` Data interface{} `json:"data"`
} }

View File

@ -335,4 +335,3 @@ func TestCommonResponseWithDifferentDataTypes(t *testing.T) {
}) })
} }
} }

View File

@ -1,92 +1,92 @@
package service package service
import ( import (
"crypto/md5" "crypto/md5"
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"strings" "strings"
) )
// AuthInfo 存储解析后的认证信息 // AuthInfo 存储解析后的认证信息
type AuthInfo struct { type AuthInfo struct {
Username string Username string
Realm string Realm string
Nonce string Nonce string
URI string URI string
Response string Response string
Algorithm string Algorithm string
Method string Method string
} }
// GenerateNonce 生成随机 nonce 字符串 // GenerateNonce 生成随机 nonce 字符串
func GenerateNonce() string { func GenerateNonce() string {
b := make([]byte, 16) b := make([]byte, 16)
rand.Read(b) rand.Read(b)
return fmt.Sprintf("%x", b) return fmt.Sprintf("%x", b)
} }
// ParseAuthorization 解析 SIP Authorization 头 // ParseAuthorization 解析 SIP Authorization 头
// Authorization: Digest username="34020000001320000001",realm="3402000000", // Authorization: Digest username="34020000001320000001",realm="3402000000",
// nonce="44010b73623249f6916a6acf7c316b8e",uri="sip:34020000002000000001@3402000000", // nonce="44010b73623249f6916a6acf7c316b8e",uri="sip:34020000002000000001@3402000000",
// response="e4ca3fdc5869fa1c544ea7af60014444",algorithm=MD5 // response="e4ca3fdc5869fa1c544ea7af60014444",algorithm=MD5
func ParseAuthorization(auth string) *AuthInfo { func ParseAuthorization(auth string) *AuthInfo {
auth = strings.TrimPrefix(auth, "Digest ") auth = strings.TrimPrefix(auth, "Digest ")
parts := strings.Split(auth, ",") parts := strings.Split(auth, ",")
result := &AuthInfo{} result := &AuthInfo{}
for _, part := range parts { for _, part := range parts {
part = strings.TrimSpace(part) part = strings.TrimSpace(part)
if !strings.Contains(part, "=") { if !strings.Contains(part, "=") {
continue continue
} }
kv := strings.SplitN(part, "=", 2) kv := strings.SplitN(part, "=", 2)
key := strings.TrimSpace(kv[0]) key := strings.TrimSpace(kv[0])
value := strings.Trim(strings.TrimSpace(kv[1]), "\"") value := strings.Trim(strings.TrimSpace(kv[1]), "\"")
switch key { switch key {
case "username": case "username":
result.Username = value result.Username = value
case "realm": case "realm":
result.Realm = value result.Realm = value
case "nonce": case "nonce":
result.Nonce = value result.Nonce = value
case "uri": case "uri":
result.URI = value result.URI = value
case "response": case "response":
result.Response = value result.Response = value
case "algorithm": case "algorithm":
result.Algorithm = value result.Algorithm = value
} }
} }
return result return result
} }
// ValidateAuth 验证 SIP 认证信息 // ValidateAuth 验证 SIP 认证信息
func ValidateAuth(authInfo *AuthInfo, password string) bool { func ValidateAuth(authInfo *AuthInfo, password string) bool {
if authInfo == nil { if authInfo == nil {
return false return false
} }
// 默认方法为 REGISTER // 默认方法为 REGISTER
method := "REGISTER" method := "REGISTER"
if authInfo.Method != "" { if authInfo.Method != "" {
method = authInfo.Method method = authInfo.Method
} }
// 计算 MD5 哈希 // 计算 MD5 哈希
ha1 := md5Hex(authInfo.Username + ":" + authInfo.Realm + ":" + password) ha1 := md5Hex(authInfo.Username + ":" + authInfo.Realm + ":" + password)
ha2 := md5Hex(method + ":" + authInfo.URI) ha2 := md5Hex(method + ":" + authInfo.URI)
correctResponse := md5Hex(ha1 + ":" + authInfo.Nonce + ":" + ha2) correctResponse := md5Hex(ha1 + ":" + authInfo.Nonce + ":" + ha2)
return authInfo.Response == correctResponse return authInfo.Response == correctResponse
} }
// md5Hex 计算字符串的 MD5 哈希值并返回十六进制字符串 // md5Hex 计算字符串的 MD5 哈希值并返回十六进制字符串
func md5Hex(s string) string { func md5Hex(s string) string {
hash := md5.New() hash := md5.New()
hash.Write([]byte(s)) hash.Write([]byte(s))
return hex.EncodeToString(hash.Sum(nil)) return hex.EncodeToString(hash.Sum(nil))
} }

View File

@ -343,4 +343,3 @@ func TestParseAuthorizationQuotedValues(t *testing.T) {
t.Logf("Realm value: '%s'", result.Realm) t.Logf("Realm value: '%s'", result.Realm)
} }
} }

View File

@ -1,17 +1,17 @@
package service package service
import ( import (
"context" "context"
"github.com/emiago/sipgo" "github.com/emiago/sipgo"
"github.com/ossrs/srs-sip/pkg/config" "github.com/ossrs/srs-sip/pkg/config"
) )
type Cascade struct { type Cascade struct {
ua *sipgo.UserAgent ua *sipgo.UserAgent
sipCli *sipgo.Client sipCli *sipgo.Client
sipSvr *sipgo.Server sipSvr *sipgo.Server
ctx context.Context ctx context.Context
conf *config.MainConfig conf *config.MainConfig
} }

View File

@ -1,171 +1,171 @@
package service package service
import ( import (
"bytes" "bytes"
"encoding/xml" "encoding/xml"
"fmt" "fmt"
"log/slog" "log/slog"
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
"github.com/emiago/sipgo/sip" "github.com/emiago/sipgo/sip"
"github.com/ossrs/srs-sip/pkg/models" "github.com/ossrs/srs-sip/pkg/models"
"github.com/ossrs/srs-sip/pkg/service/stack" "github.com/ossrs/srs-sip/pkg/service/stack"
"golang.org/x/net/html/charset" "golang.org/x/net/html/charset"
) )
const GB28181_ID_LENGTH = 20 const GB28181_ID_LENGTH = 20
func (s *UAS) isSameIP(addr1, addr2 string) bool { func (s *UAS) isSameIP(addr1, addr2 string) bool {
ip1, _, err1 := net.SplitHostPort(addr1) ip1, _, err1 := net.SplitHostPort(addr1)
ip2, _, err2 := net.SplitHostPort(addr2) ip2, _, err2 := net.SplitHostPort(addr2)
// 如果解析出错,回退到完整字符串比较 // 如果解析出错,回退到完整字符串比较
if err1 != nil || err2 != nil { if err1 != nil || err2 != nil {
return addr1 == addr2 return addr1 == addr2
} }
return ip1 == ip2 return ip1 == ip2
} }
func (s *UAS) onRegister(req *sip.Request, tx sip.ServerTransaction) { func (s *UAS) onRegister(req *sip.Request, tx sip.ServerTransaction) {
id := req.From().Address.User id := req.From().Address.User
if len(id) != GB28181_ID_LENGTH { if len(id) != GB28181_ID_LENGTH {
slog.Error("invalid device ID") slog.Error("invalid device ID")
return return
} }
slog.Debug(fmt.Sprintf("Received REGISTER %s", req.String())) slog.Debug(fmt.Sprintf("Received REGISTER %s", req.String()))
if s.conf.GB28181.Auth.Enable { if s.conf.GB28181.Auth.Enable {
// Check if Authorization header exists // Check if Authorization header exists
authHeader := req.GetHeaders("Authorization") authHeader := req.GetHeaders("Authorization")
// If no Authorization header, send 401 response to request authentication // If no Authorization header, send 401 response to request authentication
if len(authHeader) == 0 { if len(authHeader) == 0 {
nonce := GenerateNonce() nonce := GenerateNonce()
resp := stack.NewUnauthorizedResponse(req, http.StatusUnauthorized, "Unauthorized", nonce, s.conf.GB28181.Realm) resp := stack.NewUnauthorizedResponse(req, http.StatusUnauthorized, "Unauthorized", nonce, s.conf.GB28181.Realm)
_ = tx.Respond(resp) _ = tx.Respond(resp)
return return
} }
// Validate Authorization // Validate Authorization
authInfo := ParseAuthorization(authHeader[0].Value()) authInfo := ParseAuthorization(authHeader[0].Value())
if !ValidateAuth(authInfo, s.conf.GB28181.Auth.Password) { if !ValidateAuth(authInfo, s.conf.GB28181.Auth.Password) {
slog.Error("auth failed", "device_id", id, "source", req.Source()) slog.Error("auth failed", "device_id", id, "source", req.Source())
s.respondRegister(req, http.StatusForbidden, "Auth Failed", tx) s.respondRegister(req, http.StatusForbidden, "Auth Failed", tx)
return return
} }
} }
isUnregister := false isUnregister := false
if exps := req.GetHeaders("Expires"); len(exps) > 0 { if exps := req.GetHeaders("Expires"); len(exps) > 0 {
exp := exps[0] exp := exps[0]
expSec, err := strconv.ParseInt(exp.Value(), 10, 32) expSec, err := strconv.ParseInt(exp.Value(), 10, 32)
if err != nil { if err != nil {
slog.Error("parse expires header error", "error", err.Error()) slog.Error("parse expires header error", "error", err.Error())
return return
} }
if expSec == 0 { if expSec == 0 {
isUnregister = true isUnregister = true
} }
} else { } else {
slog.Error("empty expires header") slog.Error("empty expires header")
return return
} }
if isUnregister { if isUnregister {
DM.RemoveDevice(id) DM.RemoveDevice(id)
slog.Warn("Device unregistered", "device_id", id) slog.Warn("Device unregistered", "device_id", id)
return return
} else { } else {
if d, ok := DM.GetDevice(id); !ok { if d, ok := DM.GetDevice(id); !ok {
DM.AddDevice(id, &DeviceInfo{ DM.AddDevice(id, &DeviceInfo{
DeviceID: id, DeviceID: id,
SourceAddr: req.Source(), SourceAddr: req.Source(),
NetworkType: req.Transport(), NetworkType: req.Transport(),
}) })
s.respondRegister(req, http.StatusOK, "OK", tx) s.respondRegister(req, http.StatusOK, "OK", tx)
slog.Info(fmt.Sprintf("Register success %s %s", id, req.Source())) slog.Info(fmt.Sprintf("Register success %s %s", id, req.Source()))
go s.ConfigDownload(id) go s.ConfigDownload(id)
go s.Catalog(id) go s.Catalog(id)
} else { } else {
if d.SourceAddr != "" && !s.isSameIP(d.SourceAddr, req.Source()) { if d.SourceAddr != "" && !s.isSameIP(d.SourceAddr, req.Source()) {
slog.Error("Device already registered", "device_id", id, "old_source", d.SourceAddr, "new_source", req.Source()) slog.Error("Device already registered", "device_id", id, "old_source", d.SourceAddr, "new_source", req.Source())
// TODO: 如果ID重复应采用虚拟ID // TODO: 如果ID重复应采用虚拟ID
s.respondRegister(req, http.StatusBadRequest, "Conflict Device ID", tx) s.respondRegister(req, http.StatusBadRequest, "Conflict Device ID", tx)
} else { } else {
d.SourceAddr = req.Source() d.SourceAddr = req.Source()
d.NetworkType = req.Transport() d.NetworkType = req.Transport()
DM.UpdateDevice(id, d) DM.UpdateDevice(id, d)
s.respondRegister(req, http.StatusOK, "OK", tx) s.respondRegister(req, http.StatusOK, "OK", tx)
slog.Info(fmt.Sprintf("Re-register success %s %s", id, req.Source())) slog.Info(fmt.Sprintf("Re-register success %s %s", id, req.Source()))
} }
} }
} }
} }
func (s *UAS) respondRegister(req *sip.Request, code sip.StatusCode, reason string, tx sip.ServerTransaction) { func (s *UAS) respondRegister(req *sip.Request, code sip.StatusCode, reason string, tx sip.ServerTransaction) {
res := stack.NewRegisterResponse(req, code, reason) res := stack.NewRegisterResponse(req, code, reason)
_ = tx.Respond(res) _ = tx.Respond(res)
} }
func (s *UAS) onMessage(req *sip.Request, tx sip.ServerTransaction) { func (s *UAS) onMessage(req *sip.Request, tx sip.ServerTransaction) {
id := req.From().Address.User id := req.From().Address.User
if len(id) != 20 { if len(id) != 20 {
slog.Error("invalid device ID", "request", req.String()) slog.Error("invalid device ID", "request", req.String())
} }
slog.Debug(fmt.Sprintf("Received MESSAGE %s", req.String())) slog.Debug(fmt.Sprintf("Received MESSAGE %s", req.String()))
temp := &models.XmlMessageInfo{} temp := &models.XmlMessageInfo{}
decoder := xml.NewDecoder(bytes.NewReader([]byte(req.Body()))) decoder := xml.NewDecoder(bytes.NewReader([]byte(req.Body())))
decoder.CharsetReader = charset.NewReaderLabel decoder.CharsetReader = charset.NewReaderLabel
if err := decoder.Decode(temp); err != nil { if err := decoder.Decode(temp); err != nil {
slog.Error("decode message error", "error", err.Error(), "message", req.Body()) slog.Error("decode message error", "error", err.Error(), "message", req.Body())
} }
slog.Info(fmt.Sprintf("Received MESSAGE %s %s %s", temp.CmdType, temp.DeviceID, req.Source())) slog.Info(fmt.Sprintf("Received MESSAGE %s %s %s", temp.CmdType, temp.DeviceID, req.Source()))
var body string var body string
switch temp.CmdType { switch temp.CmdType {
case "Keepalive": case "Keepalive":
if d, ok := DM.GetDevice(temp.DeviceID); ok && d.Online { if d, ok := DM.GetDevice(temp.DeviceID); ok && d.Online {
// 更新设备心跳时间 // 更新设备心跳时间
DM.UpdateDeviceHeartbeat(temp.DeviceID) DM.UpdateDeviceHeartbeat(temp.DeviceID)
} else { } else {
tx.Respond(sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil)) tx.Respond(sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil))
return return
} }
case "SensorCatalog": // 兼容宇视,非国标 case "SensorCatalog": // 兼容宇视,非国标
case "Catalog": case "Catalog":
DM.UpdateChannels(temp.DeviceID, temp.DeviceList...) DM.UpdateChannels(temp.DeviceID, temp.DeviceList...)
//go s.AutoInvite(temp.DeviceID, temp.DeviceList...) //go s.AutoInvite(temp.DeviceID, temp.DeviceList...)
case "ConfigDownload": case "ConfigDownload":
DM.UpdateDeviceConfig(temp.DeviceID, &temp.BasicParam) DM.UpdateDeviceConfig(temp.DeviceID, &temp.BasicParam)
case "Alarm": case "Alarm":
slog.Info("Alarm") slog.Info("Alarm")
case "RecordInfo": case "RecordInfo":
// 从 recordQueryResults 中获取对应通道的结果通道 // 从 recordQueryResults 中获取对应通道的结果通道
if ch, ok := s.recordQueryResults.Load(temp.DeviceID); ok { if ch, ok := s.recordQueryResults.Load(temp.DeviceID); ok {
// 发送查询结果 // 发送查询结果
resultChan := ch.(chan *models.XmlMessageInfo) resultChan := ch.(chan *models.XmlMessageInfo)
resultChan <- temp resultChan <- temp
} }
default: default:
slog.Warn("Not supported CmdType", "cmd_type", temp.CmdType) slog.Warn("Not supported CmdType", "cmd_type", temp.CmdType)
response := sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil) response := sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil)
tx.Respond(response) tx.Respond(response)
return return
} }
tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", []byte(body))) tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", []byte(body)))
} }
func (s *UAS) onNotify(req *sip.Request, tx sip.ServerTransaction) { func (s *UAS) onNotify(req *sip.Request, tx sip.ServerTransaction) {
slog.Debug(fmt.Sprintf("Received NOTIFY %s", req.String())) slog.Debug(fmt.Sprintf("Received NOTIFY %s", req.String()))
tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", nil)) tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", nil))
} }

View File

@ -1,81 +1,81 @@
package service package service
import "fmt" import "fmt"
var ( var (
ptzCmdMap = map[string]uint8{ ptzCmdMap = map[string]uint8{
"stop": 0, "stop": 0,
"right": 1, "right": 1,
"left": 2, "left": 2,
"down": 4, "down": 4,
"downright": 5, "downright": 5,
"downleft": 6, "downleft": 6,
"up": 8, "up": 8,
"upright": 9, "upright": 9,
"upleft": 10, "upleft": 10,
"zoomin": 16, "zoomin": 16,
"zoomout": 32, "zoomout": 32,
} }
ptzSpeedMap = map[string]uint8{ ptzSpeedMap = map[string]uint8{
"1": 25, "1": 25,
"2": 50, "2": 50,
"3": 75, "3": 75,
"4": 100, "4": 100,
"5": 125, "5": 125,
"6": 150, "6": 150,
"7": 175, "7": 175,
"8": 200, "8": 200,
"9": 225, "9": 225,
"10": 255, "10": 255,
} }
defaultSpeed uint8 = 125 defaultSpeed uint8 = 125
) )
func getPTZSpeed(speed string) uint8 { func getPTZSpeed(speed string) uint8 {
if v, ok := ptzSpeedMap[speed]; ok { if v, ok := ptzSpeedMap[speed]; ok {
return v return v
} }
return defaultSpeed return defaultSpeed
} }
func toPTZCmd(cmdName, speed string) (string, error) { func toPTZCmd(cmdName, speed string) (string, error) {
cmdCode, ok := ptzCmdMap[cmdName] cmdCode, ok := ptzCmdMap[cmdName]
if !ok { if !ok {
return "", fmt.Errorf("invalid ptz command: %q", cmdName) return "", fmt.Errorf("invalid ptz command: %q", cmdName)
} }
speedValue := getPTZSpeed(speed) speedValue := getPTZSpeed(speed)
var horizontalSpeed, verticalSpeed, zSpeed uint8 var horizontalSpeed, verticalSpeed, zSpeed uint8
switch cmdName { switch cmdName {
case "left", "right": case "left", "right":
horizontalSpeed = speedValue horizontalSpeed = speedValue
verticalSpeed = 0 verticalSpeed = 0
case "up", "down": case "up", "down":
verticalSpeed = speedValue verticalSpeed = speedValue
horizontalSpeed = 0 horizontalSpeed = 0
case "upleft", "upright", "downleft", "downright": case "upleft", "upright", "downleft", "downright":
verticalSpeed = speedValue verticalSpeed = speedValue
horizontalSpeed = speedValue horizontalSpeed = speedValue
case "zoomin", "zoomout": case "zoomin", "zoomout":
zSpeed = speedValue << 4 // zoom速度在高4位 zSpeed = speedValue << 4 // zoom速度在高4位
default: default:
horizontalSpeed = 0 horizontalSpeed = 0
verticalSpeed = 0 verticalSpeed = 0
zSpeed = 0 zSpeed = 0
} }
sum := uint16(0xA5) + uint16(0x0F) + uint16(0x01) + uint16(cmdCode) + uint16(horizontalSpeed) + uint16(verticalSpeed) + uint16(zSpeed) sum := uint16(0xA5) + uint16(0x0F) + uint16(0x01) + uint16(cmdCode) + uint16(horizontalSpeed) + uint16(verticalSpeed) + uint16(zSpeed)
checksum := uint8(sum % 256) checksum := uint8(sum % 256)
return fmt.Sprintf("A50F01%02X%02X%02X%02X%02X", return fmt.Sprintf("A50F01%02X%02X%02X%02X%02X",
cmdCode, cmdCode,
horizontalSpeed, horizontalSpeed,
verticalSpeed, verticalSpeed,
zSpeed, zSpeed,
checksum, checksum,
), nil ), nil
} }

View File

@ -141,7 +141,7 @@ func TestToPTZCmdSpecificCases(t *testing.T) {
func TestToPTZCmdWithDifferentSpeeds(t *testing.T) { func TestToPTZCmdWithDifferentSpeeds(t *testing.T) {
speeds := []string{"1", "5", "10"} speeds := []string{"1", "5", "10"}
for _, speed := range speeds { for _, speed := range speeds {
t.Run("Right with speed "+speed, func(t *testing.T) { t.Run("Right with speed "+speed, func(t *testing.T) {
result, err := toPTZCmd("right", speed) result, err := toPTZCmd("right", speed)
@ -196,4 +196,3 @@ func TestPTZSpeedMap(t *testing.T) {
}) })
} }
} }

View File

@ -1,68 +1,68 @@
package stack package stack
import ( import (
"github.com/emiago/sipgo/sip" "github.com/emiago/sipgo/sip"
"github.com/ossrs/go-oryx-lib/errors" "github.com/ossrs/go-oryx-lib/errors"
) )
type OutboundConfig struct { type OutboundConfig struct {
Transport string Transport string
Via string Via string
From string From string
To string To string
} }
func NewRequest(method sip.RequestMethod, body []byte, conf OutboundConfig) (*sip.Request, error) { func NewRequest(method sip.RequestMethod, body []byte, conf OutboundConfig) (*sip.Request, error) {
if len(conf.From) != 20 || len(conf.To) != 20 { if len(conf.From) != 20 || len(conf.To) != 20 {
return nil, errors.Errorf("From or To length is not 20") return nil, errors.Errorf("From or To length is not 20")
} }
dest := conf.Via dest := conf.Via
to := sip.Uri{User: conf.To, Host: conf.To[:10]} to := sip.Uri{User: conf.To, Host: conf.To[:10]}
from := &sip.Uri{User: conf.From, Host: conf.From[:10]} from := &sip.Uri{User: conf.From, Host: conf.From[:10]}
fromHeader := &sip.FromHeader{Address: *from, Params: sip.NewParams()} fromHeader := &sip.FromHeader{Address: *from, Params: sip.NewParams()}
fromHeader.Params.Add("tag", sip.GenerateTagN(16)) fromHeader.Params.Add("tag", sip.GenerateTagN(16))
req := sip.NewRequest(method, to) req := sip.NewRequest(method, to)
req.AppendHeader(fromHeader) req.AppendHeader(fromHeader)
req.AppendHeader(&sip.ToHeader{Address: to}) req.AppendHeader(&sip.ToHeader{Address: to})
req.AppendHeader(&sip.ContactHeader{Address: *from}) req.AppendHeader(&sip.ContactHeader{Address: *from})
req.AppendHeader(sip.NewHeader("Max-Forwards", "70")) req.AppendHeader(sip.NewHeader("Max-Forwards", "70"))
req.SetBody(body) req.SetBody(body)
req.SetDestination(dest) req.SetDestination(dest)
req.SetTransport(conf.Transport) req.SetTransport(conf.Transport)
return req, nil return req, nil
} }
func NewRegisterRequest(conf OutboundConfig) (*sip.Request, error) { func NewRegisterRequest(conf OutboundConfig) (*sip.Request, error) {
req, err := NewRequest(sip.REGISTER, nil, conf) req, err := NewRequest(sip.REGISTER, nil, conf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.AppendHeader(sip.NewHeader("Expires", "3600")) req.AppendHeader(sip.NewHeader("Expires", "3600"))
return req, nil return req, nil
} }
func NewInviteRequest(body []byte, subject string, conf OutboundConfig) (*sip.Request, error) { func NewInviteRequest(body []byte, subject string, conf OutboundConfig) (*sip.Request, error) {
req, err := NewRequest(sip.INVITE, body, conf) req, err := NewRequest(sip.INVITE, body, conf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.AppendHeader(sip.NewHeader("Content-Type", "application/sdp")) req.AppendHeader(sip.NewHeader("Content-Type", "application/sdp"))
req.AppendHeader(sip.NewHeader("Subject", subject)) req.AppendHeader(sip.NewHeader("Subject", subject))
return req, nil return req, nil
} }
func NewMessageRequest(body []byte, conf OutboundConfig) (*sip.Request, error) { func NewMessageRequest(body []byte, conf OutboundConfig) (*sip.Request, error) {
req, err := NewRequest(sip.MESSAGE, body, conf) req, err := NewRequest(sip.MESSAGE, body, conf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.AppendHeader(sip.NewHeader("Content-Type", "Application/MANSCDP+xml")) req.AppendHeader(sip.NewHeader("Content-Type", "Application/MANSCDP+xml"))
return req, nil return req, nil
} }

View File

@ -1,41 +1,41 @@
package stack package stack
import ( import (
"fmt" "fmt"
"time" "time"
"github.com/emiago/sipgo/sip" "github.com/emiago/sipgo/sip"
) )
const TIME_LAYOUT = "2024-01-01T00:00:00" const TIME_LAYOUT = "2024-01-01T00:00:00"
const EXPIRES_TIME = 3600 const EXPIRES_TIME = 3600
func newResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response { func newResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response {
resp := sip.NewResponseFromRequest(req, code, reason, nil) resp := sip.NewResponseFromRequest(req, code, reason, nil)
newTo := &sip.ToHeader{Address: resp.To().Address, Params: sip.NewParams()} newTo := &sip.ToHeader{Address: resp.To().Address, Params: sip.NewParams()}
newTo.Params.Add("tag", sip.GenerateTagN(10)) newTo.Params.Add("tag", sip.GenerateTagN(10))
resp.ReplaceHeader(newTo) resp.ReplaceHeader(newTo)
resp.RemoveHeader("Allow") resp.RemoveHeader("Allow")
return resp return resp
} }
func NewRegisterResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response { func NewRegisterResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response {
resp := newResponse(req, code, reason) resp := newResponse(req, code, reason)
expires := sip.ExpiresHeader(EXPIRES_TIME) expires := sip.ExpiresHeader(EXPIRES_TIME)
resp.AppendHeader(&expires) resp.AppendHeader(&expires)
resp.AppendHeader(sip.NewHeader("Date", time.Now().Format(TIME_LAYOUT))) resp.AppendHeader(sip.NewHeader("Date", time.Now().Format(TIME_LAYOUT)))
return resp return resp
} }
func NewUnauthorizedResponse(req *sip.Request, code sip.StatusCode, reason, nonce, realm string) *sip.Response { func NewUnauthorizedResponse(req *sip.Request, code sip.StatusCode, reason, nonce, realm string) *sip.Response {
resp := newResponse(req, code, reason) resp := newResponse(req, code, reason)
resp.AppendHeader(sip.NewHeader("WWW-Authenticate", fmt.Sprintf(`Digest realm="%s",nonce="%s",algorithm=MD5`, realm, nonce))) resp.AppendHeader(sip.NewHeader("WWW-Authenticate", fmt.Sprintf(`Digest realm="%s",nonce="%s",algorithm=MD5`, realm, nonce)))
return resp return resp
} }

View File

@ -1,122 +1,122 @@
package service package service
import ( import (
"context" "context"
"fmt" "fmt"
"log/slog" "log/slog"
"github.com/emiago/sipgo" "github.com/emiago/sipgo"
"github.com/emiago/sipgo/sip" "github.com/emiago/sipgo/sip"
"github.com/ossrs/go-oryx-lib/errors" "github.com/ossrs/go-oryx-lib/errors"
"github.com/ossrs/srs-sip/pkg/config" "github.com/ossrs/srs-sip/pkg/config"
"github.com/ossrs/srs-sip/pkg/service/stack" "github.com/ossrs/srs-sip/pkg/service/stack"
) )
const ( const (
UserAgent = "SRS-SIP/1.0" UserAgent = "SRS-SIP/1.0"
) )
type UAC struct { type UAC struct {
*Cascade *Cascade
SN uint32 SN uint32
LocalIP string LocalIP string
} }
func NewUac() *UAC { func NewUac() *UAC {
ip, err := config.GetLocalIP() ip, err := config.GetLocalIP()
if err != nil { if err != nil {
slog.Error("get local ip failed", "error", err) slog.Error("get local ip failed", "error", err)
return nil return nil
} }
c := &UAC{ c := &UAC{
Cascade: &Cascade{}, Cascade: &Cascade{},
LocalIP: ip, LocalIP: ip,
} }
return c return c
} }
func (c *UAC) Start(agent *sipgo.UserAgent, r0 interface{}) error { func (c *UAC) Start(agent *sipgo.UserAgent, r0 interface{}) error {
var err error var err error
c.ctx = context.Background() c.ctx = context.Background()
c.conf = r0.(*config.MainConfig) c.conf = r0.(*config.MainConfig)
if agent == nil { if agent == nil {
ua, err := sipgo.NewUA(sipgo.WithUserAgent(UserAgent)) ua, err := sipgo.NewUA(sipgo.WithUserAgent(UserAgent))
if err != nil { if err != nil {
return err return err
} }
agent = ua agent = ua
} }
c.sipCli, err = sipgo.NewClient(agent, sipgo.WithClientHostname(c.LocalIP)) c.sipCli, err = sipgo.NewClient(agent, sipgo.WithClientHostname(c.LocalIP))
if err != nil { if err != nil {
return err return err
} }
c.sipSvr, err = sipgo.NewServer(agent) c.sipSvr, err = sipgo.NewServer(agent)
if err != nil { if err != nil {
return err return err
} }
c.sipSvr.OnInvite(c.onInvite) c.sipSvr.OnInvite(c.onInvite)
c.sipSvr.OnBye(c.onBye) c.sipSvr.OnBye(c.onBye)
c.sipSvr.OnMessage(c.onMessage) c.sipSvr.OnMessage(c.onMessage)
go c.doRegister() go c.doRegister()
return nil return nil
} }
func (c *UAC) Stop() { func (c *UAC) Stop() {
// TODO: 断开所有当前连接 // TODO: 断开所有当前连接
c.sipCli.Close() c.sipCli.Close()
c.sipSvr.Close() c.sipSvr.Close()
} }
func (c *UAC) doRegister() error { func (c *UAC) doRegister() error {
r, _ := stack.NewRegisterRequest(stack.OutboundConfig{ r, _ := stack.NewRegisterRequest(stack.OutboundConfig{
From: "34020000001110000001", From: "34020000001110000001",
To: "34020000002000000001", To: "34020000002000000001",
Transport: "UDP", Transport: "UDP",
Via: fmt.Sprintf("%s:%d", c.LocalIP, c.conf.GB28181.Port), Via: fmt.Sprintf("%s:%d", c.LocalIP, c.conf.GB28181.Port),
}) })
tx, err := c.sipCli.TransactionRequest(c.ctx, r) tx, err := c.sipCli.TransactionRequest(c.ctx, r)
if err != nil { if err != nil {
return errors.Wrapf(err, "transaction request error") return errors.Wrapf(err, "transaction request error")
} }
rs, _ := c.getResponse(tx) rs, _ := c.getResponse(tx)
slog.Info("register response", "response", rs.String()) slog.Info("register response", "response", rs.String())
return nil return nil
} }
func (c *UAC) OnRequest(req *sip.Request, tx sip.ServerTransaction) { func (c *UAC) OnRequest(req *sip.Request, tx sip.ServerTransaction) {
switch req.Method { switch req.Method {
case "INVITE": case "INVITE":
c.onInvite(req, tx) c.onInvite(req, tx)
} }
} }
func (c *UAC) onInvite(req *sip.Request, tx sip.ServerTransaction) { func (c *UAC) onInvite(req *sip.Request, tx sip.ServerTransaction) {
slog.Debug("onInvite") slog.Debug("onInvite")
} }
func (c *UAC) onBye(req *sip.Request, tx sip.ServerTransaction) { func (c *UAC) onBye(req *sip.Request, tx sip.ServerTransaction) {
slog.Debug("onBye") slog.Debug("onBye")
} }
func (c *UAC) onMessage(req *sip.Request, tx sip.ServerTransaction) { func (c *UAC) onMessage(req *sip.Request, tx sip.ServerTransaction) {
slog.Debug("onMessage", "request", req.String()) slog.Debug("onMessage", "request", req.String())
} }
func (c *UAC) getResponse(tx sip.ClientTransaction) (*sip.Response, error) { func (c *UAC) getResponse(tx sip.ClientTransaction) (*sip.Response, error) {
select { select {
case <-tx.Done(): case <-tx.Done():
return nil, fmt.Errorf("transaction died") return nil, fmt.Errorf("transaction died")
case res := <-tx.Responses(): case res := <-tx.Responses():
return res, nil return res, nil
} }
} }

View File

@ -1,201 +1,201 @@
package utils package utils
import ( import (
"context" "context"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
) )
var logLevelMap = map[string]slog.Level{ var logLevelMap = map[string]slog.Level{
"debug": slog.LevelDebug, "debug": slog.LevelDebug,
"info": slog.LevelInfo, "info": slog.LevelInfo,
"warn": slog.LevelWarn, "warn": slog.LevelWarn,
"error": slog.LevelError, "error": slog.LevelError,
} }
// 自定义格式处理器,以 [时间] [级别] [消息] 格式输出日志 // 自定义格式处理器,以 [时间] [级别] [消息] 格式输出日志
type CustomFormatHandler struct { type CustomFormatHandler struct {
mu sync.Mutex mu sync.Mutex
w io.Writer w io.Writer
level slog.Level level slog.Level
attrs []slog.Attr attrs []slog.Attr
groups []string groups []string
} }
// NewCustomFormatHandler 创建一个新的自定义格式处理器 // NewCustomFormatHandler 创建一个新的自定义格式处理器
func NewCustomFormatHandler(w io.Writer, opts *slog.HandlerOptions) *CustomFormatHandler { func NewCustomFormatHandler(w io.Writer, opts *slog.HandlerOptions) *CustomFormatHandler {
if opts == nil { if opts == nil {
opts = &slog.HandlerOptions{} opts = &slog.HandlerOptions{}
} }
// 获取日志级别如果opts.Level是nil则默认为Info // 获取日志级别如果opts.Level是nil则默认为Info
var level slog.Level var level slog.Level
if opts.Level != nil { if opts.Level != nil {
level = opts.Level.Level() level = opts.Level.Level()
} else { } else {
level = slog.LevelInfo level = slog.LevelInfo
} }
return &CustomFormatHandler{ return &CustomFormatHandler{
w: w, w: w,
level: level, level: level,
} }
} }
// Enabled 实现 slog.Handler 接口 // Enabled 实现 slog.Handler 接口
func (h *CustomFormatHandler) Enabled(ctx context.Context, level slog.Level) bool { func (h *CustomFormatHandler) Enabled(ctx context.Context, level slog.Level) bool {
return level >= h.level return level >= h.level
} }
// Handle 实现 slog.Handler 接口,以自定义格式输出日志 // Handle 实现 slog.Handler 接口,以自定义格式输出日志
func (h *CustomFormatHandler) Handle(ctx context.Context, record slog.Record) error { func (h *CustomFormatHandler) Handle(ctx context.Context, record slog.Record) error {
h.mu.Lock() h.mu.Lock()
defer h.mu.Unlock() defer h.mu.Unlock()
// 时间格式 // 时间格式
timeStr := record.Time.Format("2006-01-02 15:04:05.000") timeStr := record.Time.Format("2006-01-02 15:04:05.000")
// 日志级别 // 日志级别
var levelStr string var levelStr string
switch { switch {
case record.Level >= slog.LevelError: case record.Level >= slog.LevelError:
levelStr = "ERROR" levelStr = "ERROR"
case record.Level >= slog.LevelWarn: case record.Level >= slog.LevelWarn:
levelStr = "WARN " levelStr = "WARN "
case record.Level >= slog.LevelInfo: case record.Level >= slog.LevelInfo:
levelStr = "INFO " levelStr = "INFO "
default: default:
levelStr = "DEBUG" levelStr = "DEBUG"
} }
// 构建日志行 // 构建日志行
logLine := fmt.Sprintf("[%s] [%s] %s", timeStr, levelStr, record.Message) logLine := fmt.Sprintf("[%s] [%s] %s", timeStr, levelStr, record.Message)
// 处理其他属性 // 处理其他属性
var attrs []string var attrs []string
record.Attrs(func(attr slog.Attr) bool { record.Attrs(func(attr slog.Attr) bool {
attrs = append(attrs, fmt.Sprintf("%s=%v", attr.Key, attr.Value)) attrs = append(attrs, fmt.Sprintf("%s=%v", attr.Key, attr.Value))
return true return true
}) })
if len(attrs) > 0 { if len(attrs) > 0 {
logLine += " " + strings.Join(attrs, " ") logLine += " " + strings.Join(attrs, " ")
} }
// 写入日志 // 写入日志
_, err := fmt.Fprintln(h.w, logLine) _, err := fmt.Fprintln(h.w, logLine)
return err return err
} }
// WithAttrs 实现 slog.Handler 接口 // WithAttrs 实现 slog.Handler 接口
func (h *CustomFormatHandler) WithAttrs(attrs []slog.Attr) slog.Handler { func (h *CustomFormatHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
h2 := *h h2 := *h
h2.attrs = append(h.attrs[:], attrs...) h2.attrs = append(h.attrs[:], attrs...)
return &h2 return &h2
} }
// WithGroup 实现 slog.Handler 接口 // WithGroup 实现 slog.Handler 接口
func (h *CustomFormatHandler) WithGroup(name string) slog.Handler { func (h *CustomFormatHandler) WithGroup(name string) slog.Handler {
h2 := *h h2 := *h
h2.groups = append(h.groups[:], name) h2.groups = append(h.groups[:], name)
return &h2 return &h2
} }
// MultiHandler 实现了 slog.Handler 接口,将日志同时发送到多个处理器 // MultiHandler 实现了 slog.Handler 接口,将日志同时发送到多个处理器
type MultiHandler struct { type MultiHandler struct {
handlers []slog.Handler handlers []slog.Handler
} }
// Enabled 实现 slog.Handler 接口 // Enabled 实现 slog.Handler 接口
func (h *MultiHandler) Enabled(ctx context.Context, level slog.Level) bool { func (h *MultiHandler) Enabled(ctx context.Context, level slog.Level) bool {
// 如果任何一个处理器启用了该级别,则返回 true // 如果任何一个处理器启用了该级别,则返回 true
for _, handler := range h.handlers { for _, handler := range h.handlers {
if handler.Enabled(ctx, level) { if handler.Enabled(ctx, level) {
return true return true
} }
} }
return false return false
} }
// Handle 实现 slog.Handler 接口 // Handle 实现 slog.Handler 接口
func (h *MultiHandler) Handle(ctx context.Context, record slog.Record) error { func (h *MultiHandler) Handle(ctx context.Context, record slog.Record) error {
// 将记录发送到所有处理器 // 将记录发送到所有处理器
for _, handler := range h.handlers { for _, handler := range h.handlers {
if handler.Enabled(ctx, record.Level) { if handler.Enabled(ctx, record.Level) {
if err := handler.Handle(ctx, record); err != nil { if err := handler.Handle(ctx, record); err != nil {
return err return err
} }
} }
} }
return nil return nil
} }
// WithAttrs 实现 slog.Handler 接口 // WithAttrs 实现 slog.Handler 接口
func (h *MultiHandler) WithAttrs(attrs []slog.Attr) slog.Handler { func (h *MultiHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
newHandlers := make([]slog.Handler, len(h.handlers)) newHandlers := make([]slog.Handler, len(h.handlers))
for i, handler := range h.handlers { for i, handler := range h.handlers {
newHandlers[i] = handler.WithAttrs(attrs) newHandlers[i] = handler.WithAttrs(attrs)
} }
return &MultiHandler{handlers: newHandlers} return &MultiHandler{handlers: newHandlers}
} }
// WithGroup 实现 slog.Handler 接口 // WithGroup 实现 slog.Handler 接口
func (h *MultiHandler) WithGroup(name string) slog.Handler { func (h *MultiHandler) WithGroup(name string) slog.Handler {
newHandlers := make([]slog.Handler, len(h.handlers)) newHandlers := make([]slog.Handler, len(h.handlers))
for i, handler := range h.handlers { for i, handler := range h.handlers {
newHandlers[i] = handler.WithGroup(name) newHandlers[i] = handler.WithGroup(name)
} }
return &MultiHandler{handlers: newHandlers} return &MultiHandler{handlers: newHandlers}
} }
// SetupLogger 设置日志输出 // SetupLogger 设置日志输出
func SetupLogger(logLevel string, logFile string) error { func SetupLogger(logLevel string, logFile string) error {
// 创建标准错误输出的处理器,使用自定义格式 // 创建标准错误输出的处理器,使用自定义格式
stdHandler := NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{ stdHandler := NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{
Level: logLevelMap[logLevel], Level: logLevelMap[logLevel],
}) })
// 如果没有指定日志文件,则仅使用标准错误处理器 // 如果没有指定日志文件,则仅使用标准错误处理器
if logFile == "" { if logFile == "" {
slog.SetDefault(slog.New(stdHandler)) slog.SetDefault(slog.New(stdHandler))
return nil return nil
} }
// 确保日志文件所在目录存在 // 确保日志文件所在目录存在
logDir := filepath.Dir(logFile) logDir := filepath.Dir(logFile)
if err := os.MkdirAll(logDir, 0755); err != nil { if err := os.MkdirAll(logDir, 0755); err != nil {
return err return err
} }
// 打开日志文件,如果不存在则创建,追加写入模式 // 打开日志文件,如果不存在则创建,追加写入模式
file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil { if err != nil {
return err return err
} }
// 创建文件输出的处理器,使用自定义格式 // 创建文件输出的处理器,使用自定义格式
fileHandler := NewCustomFormatHandler(file, &slog.HandlerOptions{ fileHandler := NewCustomFormatHandler(file, &slog.HandlerOptions{
Level: logLevelMap[logLevel], Level: logLevelMap[logLevel],
}) })
// 创建多输出处理器 // 创建多输出处理器
multiHandler := &MultiHandler{ multiHandler := &MultiHandler{
handlers: []slog.Handler{stdHandler, fileHandler}, handlers: []slog.Handler{stdHandler, fileHandler},
} }
// 设置全局日志处理器 // 设置全局日志处理器
slog.SetDefault(slog.New(multiHandler)) slog.SetDefault(slog.New(multiHandler))
return nil return nil
} }
// InitDefaultLogger 初始化默认日志处理器 // InitDefaultLogger 初始化默认日志处理器
func InitDefaultLogger(level slog.Level) { func InitDefaultLogger(level slog.Level) {
slog.SetDefault(slog.New(NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{ slog.SetDefault(slog.New(NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{
Level: level, Level: level,
}))) })))
} }

View File

@ -167,4 +167,3 @@ func TestGetSessionName(t *testing.T) {
}) })
} }
} }

View File

@ -1,30 +1,30 @@
package main package main
import ( import (
"context" "context"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"github.com/ossrs/go-oryx-lib/logger" "github.com/ossrs/go-oryx-lib/logger"
"github.com/ossrs/srs-bench/gb28181" "github.com/ossrs/srs-bench/gb28181"
) )
func main() { func main() {
ctx := context.Background() ctx := context.Background()
var conf interface{} var conf interface{}
conf = gb28181.Parse(ctx) conf = gb28181.Parse(ctx)
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
go func() { go func() {
sigs := make(chan os.Signal, 1) sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT) signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT)
for sig := range sigs { for sig := range sigs {
logger.Wf(ctx, "Quit for signal %v", sig) logger.Wf(ctx, "Quit for signal %v", sig)
cancel() cancel()
} }
}() }()
gb28181.Run(ctx, conf) gb28181.Run(ctx, conf)
} }