Compare commits

...

10 Commits

Author SHA1 Message Date
7e49007d16 Add Docker support and configuration for SRS SIP
- Created a new README_cross.md file with Docker build instructions.
- Updated srs.conf to include logging configuration options.
- Added docker-compose.yml to define the SRS SIP service with necessary ports and volume mappings.
- Introduced config.yaml for general and GB28181-specific configurations.
- Added initial srs.conf with settings for RTMP, HTTP API, and WebRTC support.
2026-01-13 11:28:56 +08:00
2aa65de911 security 2025-10-15 16:04:35 +08:00
35de09aeb6 ut 2025-10-15 15:35:41 +08:00
1178b974a1 codeql 2025-10-15 14:18:47 +08:00
156f07644d gofmt 2025-10-15 10:05:52 +08:00
d9709f61a5 codecov 2025-10-15 09:59:32 +08:00
59bc95ab21 update 2025-10-15 09:29:38 +08:00
4c7485f4ef unit test 2025-10-15 09:14:33 +08:00
b0fce4380f fix warn 2025-10-14 16:51:37 +08:00
a92d1624c5 ci 2025-10-14 16:48:21 +08:00
36 changed files with 8692 additions and 1229 deletions

22
.github/codeql/codeql-config.yml vendored Normal file
View File

@ -0,0 +1,22 @@
name: "CodeQL Config"
# 指定要扫描的路径
paths:
- pkg
- main
# 排除不需要扫描的路径
paths-ignore:
- '**/*_test.go'
- 'html/**'
- 'objs/**'
- 'vendor/**'
# 使用的查询套件
queries:
- uses: security-extended
- uses: security-and-quality
# 禁用默认查询(如果只想使用自定义查询)
# disable-default-queries: true

82
.github/workflows/ci.yml vendored Normal file
View File

@ -0,0 +1,82 @@
name: CI
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build-and-test:
name: Build and Test
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: '1.23'
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: '20'
- name: Cache Go modules
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Cache Node modules
uses: actions/cache@v4
with:
path: html/NextGB/node_modules
key: ${{ runner.os }}-node-${{ hashFiles('html/NextGB/package-lock.json') }}
restore-keys: |
${{ runner.os }}-node-
- name: Download Go dependencies
run: go mod download
- name: Build Go application
run: make build
- name: Run Go tests
run: go test -v ./...
- name: Run Go tests with coverage
run: go test ./pkg/... -coverprofile=coverage.out -covermode=atomic
- name: Display coverage report
run: go tool cover -func=coverage.out
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: unittests
name: codecov-umbrella
fail_ci_if_error: false
- name: Install Vue dependencies
run: make vue-install
- name: Build Vue application
run: make vue-build
- name: Upload build artifacts
uses: actions/upload-artifact@v4
if: success()
with:
name: srs-sip-build
path: |
objs/srs-sip
html/NextGB/dist/
retention-days: 7

52
.github/workflows/codeql.yml vendored Normal file
View File

@ -0,0 +1,52 @@
name: "CodeQL"
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
schedule:
# 每周一凌晨2点运行
- cron: '0 2 * * 1'
jobs:
analyze:
name: Analyze Go Code
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
security-events: write
strategy:
fail-fast: false
matrix:
language: [ 'go' ]
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: '1.23'
cache: true
# 初始化 CodeQL
- name: Initialize CodeQL
uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
config-file: ./.github/codeql/codeql-config.yml
# 自动构建
- name: Autobuild
uses: github/codeql-action/autobuild@v3
# 执行 CodeQL 分析
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v3
with:
category: "/language:${{matrix.language}}"

View File

@ -1,16 +1,29 @@
# 引入SRS
FROM ossrs/srs:v6.0.155 AS srs
FROM ossrs/srs:v6.0.184 AS srs
# 前端构建阶段
FROM node:20-slim AS frontend-builder
ARG HTTP_PROXY=
ARG NO_PROXY=
ENV http_proxy=${HTTP_PROXY} \
https_proxy=${HTTP_PROXY} \
no_proxy=${NO_PROXY}
WORKDIR /app/frontend
COPY html/NextGB/package*.json ./
RUN npm install
# RUN npm config set registry http://mirrors.cloud.tencent.com/npm/ \
# && npm install
RUN npm install
COPY html/NextGB/ .
RUN npm run build
# 后端构建阶段
FROM golang:1.23 AS backend-builder
ARG HTTP_PROXY=
ARG NO_PROXY=
ENV http_proxy=${HTTP_PROXY} \
https_proxy=${HTTP_PROXY} \
no_proxy=${NO_PROXY} \
GOPROXY=https://goproxy.cn,direct
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
@ -19,11 +32,20 @@ RUN CGO_ENABLED=0 GOOS=linux go build -o /app/srs-sip main/main.go
# 最终运行阶段
FROM ubuntu:22.04
ARG HTTP_PROXY=
ARG NO_PROXY=
ENV http_proxy=${HTTP_PROXY} \
https_proxy=${HTTP_PROXY} \
no_proxy=${NO_PROXY}
WORKDIR /usr/local
# 设置时区
ENV TZ=Asia/Shanghai
RUN apt-get update && \
RUN sed -i \
-e 's@http://archive.ubuntu.com/ubuntu/@http://mirrors.ustc.edu.cn/ubuntu/@g' \
-e 's@http://security.ubuntu.com/ubuntu/@http://mirrors.ustc.edu.cn/ubuntu/@g' \
/etc/apt/sources.list && \
apt-get update && \
apt-get install -y ca-certificates tzdata supervisor && \
ln -fs /usr/share/zoneinfo/$TZ /etc/localtime && \
dpkg-reconfigure -f noninteractive tzdata && \
@ -60,7 +82,7 @@ stderr_logfile=/dev/stderr\n\
stderr_logfile_maxbytes=0\n\
\n\
[program:srs-sip]\n\
command=/usr/local/srs-sip/srs-sip\n\
command=/usr/local/srs-sip/srs-sip -c /usr/local/srs-sip/config.yaml\n\
directory=/usr/local/srs-sip\n\
autostart=true\n\
autorestart=true\n\
@ -71,4 +93,4 @@ stderr_logfile_maxbytes=0" > /etc/supervisor/conf.d/supervisord.conf
EXPOSE 1935 5060 8025 9000 5060/udp 8000/udp
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]

4530
GBT+28181-2022.md Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,11 @@
# SRS-SIP
[![CI](https://github.com/ossrs/srs-sip/actions/workflows/ci.yml/badge.svg)](https://github.com/ossrs/srs-sip/actions/workflows/ci.yml)
[![CodeQL](https://github.com/ossrs/srs-sip/actions/workflows/codeql.yml/badge.svg)](https://github.com/ossrs/srs-sip/actions/workflows/codeql.yml)
[![codecov](https://codecov.io/gh/ossrs/srs-sip/branch/main/graph/badge.svg)](https://codecov.io/gh/ossrs/srs-sip)
[![Go Report Card](https://goreportcard.com/badge/github.com/ossrs/srs-sip)](https://goreportcard.com/report/github.com/ossrs/srs-sip)
[![License](https://img.shields.io/github/license/ossrs/srs-sip)](https://github.com/ossrs/srs-sip/blob/main/LICENSE)
## Usage
Pre-requisites:
@ -24,6 +30,29 @@ Run the program:
./objs/srs-sip -c conf/config.yaml
```
## Testing
Run all tests:
```bash
go test -v ./pkg/...
```
Run tests with coverage:
```bash
go test ./pkg/... -coverprofile=coverage.out
go tool cover -func=coverage.out
```
For more details, see [Testing Guide](docs/TESTING.md).
## Security
This project uses CodeQL for automated security scanning. For more information about security practices and how to report vulnerabilities, see [Security Guide](docs/SECURITY.md).
## Docker
Use docker
```
docker run -id -p 1985:1985 -p 5060:5060 -p 8025:8025 -p 9000:9000 -p 5060:5060/udp -p 8000:8000/udp --name srs-sip --env CANDIDATE=your_ip ossrs/srs-sip:alpha

4
README_cross.md Normal file
View File

@ -0,0 +1,4 @@
```bash
docker compose build --network host
```

View File

@ -3,6 +3,11 @@ max_connections 1000;
# For docker, please use docker logs to manage the logs of SRS.
# See https://docs.docker.com/config/containers/logging/
srs_log_tank console;
# srs_log_tank file;
# srs_log_file /var/log/srs/srs.log;
# ff_log_dir /var/log/srs;
daemon off;
disable_daemon_for_docker off;
http_api {
@ -57,4 +62,4 @@ vhost __defaultVhost__ {
rtc_to_rtmp on;
pli_for_rtmp 6.0;
}
}
}

30
docker-compose.yml Normal file
View File

@ -0,0 +1,30 @@
services:
srs-sip:
build:
context: .
network: host
args:
HTTP_PROXY: ${HTTP_PROXY:-}
NO_PROXY: "localhost,127.0.0.1,::1"
environment:
# CANDIDATE: ${CANDIDATE:-}
CANDIDATE: 192.168.2.184
TZ: "Asia/Shanghai"
volumes:
- ./run/conf/config.yaml:/usr/local/srs-sip/config.yaml:ro
- ./run/logs:/usr/local/srs-sip/logs
- ./run/srs/conf/srs.conf:/usr/local/srs/conf/srs.conf:ro
- ./run/srs/logs:/var/log/srs/
ports:
# SRS RTMP
- "1985:1985"
# SRS media ingest (GB28181 RTP/PS 等,取决于你的配置)
- "9000:9000"
# SRS WebRTC
- "8000:8000/udp"
# SIP
- "5060:5060"
- "5060:5060/udp"
# WebUI
- "8025:8025"
restart: unless-stopped

View File

@ -8,6 +8,7 @@ import (
"os"
"os/signal"
"path"
"path/filepath"
"strconv"
"strings"
"syscall"
@ -95,9 +96,42 @@ func main() {
return
}
// 首先检查原始路径是否包含 ".." 以防止路径遍历攻击
if strings.Contains(r.URL.Path, "..") {
slog.Warn("potential path traversal attempt detected", "path", r.URL.Path)
http.Error(w, "Invalid path", http.StatusBadRequest)
return
}
// 清理路径
cleanPath := path.Clean(r.URL.Path)
// 检查请求的文件是否存在
filePath := path.Join(conf.Http.Dir, r.URL.Path)
_, err := os.Stat(filePath)
filePath := filepath.Join(conf.Http.Dir, cleanPath)
// 确保最终路径在允许的目录内
absDir, err := filepath.Abs(conf.Http.Dir)
if err != nil {
slog.Error("failed to get absolute path of http dir", "error", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
absFilePath, err := filepath.Abs(filePath)
if err != nil {
slog.Error("failed to get absolute path of file", "error", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
// 验证文件路径在允许的目录内
if !strings.HasPrefix(absFilePath, absDir) {
slog.Warn("path traversal attempt blocked", "requested", r.URL.Path, "resolved", absFilePath)
http.Error(w, "Access denied", http.StatusForbidden)
return
}
_, err = os.Stat(absFilePath)
if os.IsNotExist(err) {
// 如果文件不存在,返回 index.html
slog.Info("file not found, redirect to index", "path", r.URL.Path)

247
main/main_test.go Normal file
View File

@ -0,0 +1,247 @@
package main
import (
"path"
"path/filepath"
"strings"
"testing"
)
// TestPathTraversalPrevention 测试路径遍历防护
func TestPathTraversalPrevention(t *testing.T) {
baseDir := "/var/www/html"
tests := []struct {
name string
inputPath string
shouldFail bool
reason string
}{
{
name: "Normal file",
inputPath: "/index.html",
shouldFail: false,
reason: "Normal file access should be allowed",
},
{
name: "Subdirectory file",
inputPath: "/css/style.css",
shouldFail: false,
reason: "Subdirectory access should be allowed",
},
{
name: "Deep subdirectory",
inputPath: "/js/lib/jquery.min.js",
shouldFail: false,
reason: "Deep subdirectory access should be allowed",
},
{
name: "Parent directory traversal",
inputPath: "/../etc/passwd",
shouldFail: true,
reason: "Parent directory traversal should be blocked",
},
{
name: "Double parent traversal",
inputPath: "/../../etc/passwd",
shouldFail: true,
reason: "Double parent traversal should be blocked",
},
{
name: "Multiple parent traversal",
inputPath: "/../../../etc/passwd",
shouldFail: true,
reason: "Multiple parent traversal should be blocked",
},
{
name: "Mixed path with parent",
inputPath: "/css/../../etc/passwd",
shouldFail: true,
reason: "Mixed path with parent should be blocked",
},
{
name: "Dot slash path",
inputPath: "/./index.html",
shouldFail: false,
reason: "Dot slash should be cleaned but allowed",
},
{
name: "Complex traversal",
inputPath: "/css/../js/../../../etc/passwd",
shouldFail: true,
reason: "Complex traversal should be blocked",
},
{
name: "Root path",
inputPath: "/",
shouldFail: false,
reason: "Root path should be allowed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟修复后的路径验证逻辑
// 首先检查原始路径是否包含 ".."
containsDoubleDotInOriginal := strings.Contains(tt.inputPath, "..")
// 如果原始路径包含 "..",直接阻止
if containsDoubleDotInOriginal {
if !tt.shouldFail {
t.Errorf("%s: Path contains '..' but should be allowed: %s", tt.reason, tt.inputPath)
}
t.Logf("Input: %s, Contains '..': true, Blocked: true (early check)", tt.inputPath)
return
}
// 清理路径
cleanPath := path.Clean(tt.inputPath)
// 构建文件路径
filePath := filepath.Join(baseDir, cleanPath)
// 获取绝对路径
absDir, err := filepath.Abs(baseDir)
if err != nil {
t.Fatalf("Failed to get absolute path of base dir: %v", err)
}
absFilePath, err := filepath.Abs(filePath)
if err != nil {
t.Fatalf("Failed to get absolute path of file: %v", err)
}
// 验证路径是否在允许的目录内
isOutsideBaseDir := !strings.HasPrefix(absFilePath, absDir)
// 判断是否应该被阻止
shouldBlock := isOutsideBaseDir
if tt.shouldFail && !shouldBlock {
t.Errorf("%s: Expected path to be blocked, but it was allowed. Path: %s, Clean: %s, Abs: %s",
tt.reason, tt.inputPath, cleanPath, absFilePath)
}
if !tt.shouldFail && shouldBlock {
t.Errorf("%s: Expected path to be allowed, but it was blocked. Path: %s, Clean: %s, Abs: %s",
tt.reason, tt.inputPath, cleanPath, absFilePath)
}
// 额外的日志信息用于调试
t.Logf("Input: %s, Clean: %s, Outside base: %v, Blocked: %v",
tt.inputPath, cleanPath, isOutsideBaseDir, shouldBlock)
})
}
}
// TestPathCleanBehavior 测试 path.Clean 的行为
func TestPathCleanBehavior(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"/index.html", "/index.html"},
{"/../etc/passwd", "/etc/passwd"},
{"/./index.html", "/index.html"},
{"/css/../index.html", "/index.html"},
{"//double//slash", "/double/slash"},
{"/trailing/slash/", "/trailing/slash"},
{"/./././index.html", "/index.html"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := path.Clean(tt.input)
if result != tt.expected {
t.Errorf("path.Clean(%q) = %q, expected %q", tt.input, result, tt.expected)
}
})
}
}
// TestAbsolutePathValidation 测试绝对路径验证
func TestAbsolutePathValidation(t *testing.T) {
// 使用临时目录进行测试
baseDir := t.TempDir()
tests := []struct {
name string
path string
shouldFail bool
}{
{
name: "File in base directory",
path: "index.html",
shouldFail: false,
},
{
name: "File in subdirectory",
path: "css/style.css",
shouldFail: false,
},
{
name: "Attempt to escape with parent",
path: "../outside.txt",
shouldFail: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cleanPath := path.Clean(tt.path)
filePath := filepath.Join(baseDir, cleanPath)
absDir, err := filepath.Abs(baseDir)
if err != nil {
t.Fatalf("Failed to get absolute path: %v", err)
}
absFilePath, err := filepath.Abs(filePath)
if err != nil {
t.Fatalf("Failed to get absolute file path: %v", err)
}
isOutside := !strings.HasPrefix(absFilePath, absDir)
if tt.shouldFail && !isOutside {
t.Errorf("Expected path to be outside base dir, but it wasn't: %s", absFilePath)
}
if !tt.shouldFail && isOutside {
t.Errorf("Expected path to be inside base dir, but it wasn't: %s", absFilePath)
}
})
}
}
// BenchmarkPathValidation 性能基准测试
func BenchmarkPathValidation(b *testing.B) {
baseDir := "/var/www/html"
testPath := "/css/style.css"
b.ResetTimer()
for i := 0; i < b.N; i++ {
cleanPath := path.Clean(testPath)
_ = strings.Contains(cleanPath, "..")
filePath := filepath.Join(baseDir, cleanPath)
absDir, _ := filepath.Abs(baseDir)
absFilePath, _ := filepath.Abs(filePath)
_ = strings.HasPrefix(absFilePath, absDir)
}
}
// BenchmarkPathValidationMalicious 恶意路径的性能测试
func BenchmarkPathValidationMalicious(b *testing.B) {
baseDir := "/var/www/html"
testPath := "/../../../etc/passwd"
b.ResetTimer()
for i := 0; i < b.N; i++ {
cleanPath := path.Clean(testPath)
_ = strings.Contains(cleanPath, "..")
filePath := filepath.Join(baseDir, cleanPath)
absDir, _ := filepath.Abs(baseDir)
absFilePath, _ := filepath.Abs(filePath)
_ = strings.HasPrefix(absFilePath, absDir)
}
}

View File

@ -95,7 +95,7 @@ func GetLocalIP() (string, error) {
}
var candidates []Iface
for _, ifc := range ifaces {
if ifc.Flags&net.FlagUp == 0 || ifc.Flags&net.FlagUp == 0 {
if ifc.Flags&net.FlagUp == 0 {
continue
}
if ifc.Flags&(net.FlagPointToPoint|net.FlagLoopback) != 0 {

154
pkg/config/config_test.go Normal file
View File

@ -0,0 +1,154 @@
package config
import (
"os"
"testing"
)
func TestDefaultConfig(t *testing.T) {
cfg := DefaultConfig()
// 测试 Common 配置
if cfg.Common.LogLevel != "info" {
t.Errorf("Expected log level 'info', got '%s'", cfg.Common.LogLevel)
}
if cfg.Common.LogFile != "app.log" {
t.Errorf("Expected log file 'app.log', got '%s'", cfg.Common.LogFile)
}
// 测试 GB28181 配置
if cfg.GB28181.Serial != "34020000002000000001" {
t.Errorf("Expected serial '34020000002000000001', got '%s'", cfg.GB28181.Serial)
}
if cfg.GB28181.Realm != "3402000000" {
t.Errorf("Expected realm '3402000000', got '%s'", cfg.GB28181.Realm)
}
if cfg.GB28181.Host != "0.0.0.0" {
t.Errorf("Expected host '0.0.0.0', got '%s'", cfg.GB28181.Host)
}
if cfg.GB28181.Port != 5060 {
t.Errorf("Expected port 5060, got %d", cfg.GB28181.Port)
}
if cfg.GB28181.Auth.Enable != false {
t.Errorf("Expected auth enable false, got %v", cfg.GB28181.Auth.Enable)
}
if cfg.GB28181.Auth.Password != "123456" {
t.Errorf("Expected auth password '123456', got '%s'", cfg.GB28181.Auth.Password)
}
// 测试 HTTP 配置
if cfg.Http.Port != 8025 {
t.Errorf("Expected http port 8025, got %d", cfg.Http.Port)
}
if cfg.Http.Dir != "./html" {
t.Errorf("Expected http dir './html', got '%s'", cfg.Http.Dir)
}
}
func TestLoadConfigNonExistent(t *testing.T) {
// 测试加载不存在的配置文件,应该返回默认配置
cfg, err := LoadConfig("non_existent_config.yaml")
if err != nil {
t.Fatalf("Expected no error for non-existent config, got: %v", err)
}
// 应该返回默认配置
defaultCfg := DefaultConfig()
if cfg.Common.LogLevel != defaultCfg.Common.LogLevel {
t.Errorf("Expected default log level, got '%s'", cfg.Common.LogLevel)
}
}
func TestLoadConfigValid(t *testing.T) {
// 创建临时配置文件
tempFile := "test_config.yaml"
defer os.Remove(tempFile)
configContent := `common:
log-level: debug
log-file: test.log
gb28181:
serial: "12345678901234567890"
realm: "1234567890"
host: "127.0.0.1"
port: 5061
auth:
enable: true
password: "test123"
http:
listen: 9000
dir: "./test_html"
`
err := os.WriteFile(tempFile, []byte(configContent), 0644)
if err != nil {
t.Fatalf("Failed to create test config file: %v", err)
}
// 加载配置
cfg, err := LoadConfig(tempFile)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
// 验证配置
if cfg.Common.LogLevel != "debug" {
t.Errorf("Expected log level 'debug', got '%s'", cfg.Common.LogLevel)
}
if cfg.Common.LogFile != "test.log" {
t.Errorf("Expected log file 'test.log', got '%s'", cfg.Common.LogFile)
}
if cfg.GB28181.Serial != "12345678901234567890" {
t.Errorf("Expected serial '12345678901234567890', got '%s'", cfg.GB28181.Serial)
}
if cfg.GB28181.Port != 5061 {
t.Errorf("Expected port 5061, got %d", cfg.GB28181.Port)
}
if cfg.GB28181.Auth.Enable != true {
t.Errorf("Expected auth enable true, got %v", cfg.GB28181.Auth.Enable)
}
if cfg.Http.Port != 9000 {
t.Errorf("Expected http port 9000, got %d", cfg.Http.Port)
}
}
func TestLoadConfigInvalid(t *testing.T) {
// 创建无效的配置文件
tempFile := "test_invalid_config.yaml"
defer os.Remove(tempFile)
invalidContent := `invalid yaml content: [[[`
err := os.WriteFile(tempFile, []byte(invalidContent), 0644)
if err != nil {
t.Fatalf("Failed to create test config file: %v", err)
}
// 加载配置应该失败
_, err = LoadConfig(tempFile)
if err == nil {
t.Error("Expected error for invalid config file, got nil")
}
}
func TestGetLocalIP(t *testing.T) {
ip, err := GetLocalIP()
// 在某些环境下可能没有网络接口,所以允许返回错误
if err != nil {
t.Logf("GetLocalIP returned error (may be expected in some environments): %v", err)
return
}
// 如果成功,验证返回的是有效的 IP 地址
if ip == "" {
t.Error("Expected non-empty IP address")
}
// 简单验证 IP 格式(应该包含点号)
if len(ip) < 7 { // 最短的 IP 是 0.0.0.0
t.Errorf("IP address seems invalid: %s", ip)
}
t.Logf("Local IP: %s", ip)
}

View File

@ -1,152 +1,152 @@
package db
import (
"database/sql"
"sync"
"github.com/ossrs/srs-sip/pkg/models"
_ "modernc.org/sqlite"
)
var (
instance *MediaServerDB
once sync.Once
)
type MediaServerDB struct {
models.MediaServerResponse
db *sql.DB
}
// GetInstance 返回 MediaServerDB 的单例实例
func GetInstance(dbPath string) (*MediaServerDB, error) {
var err error
once.Do(func() {
instance, err = NewMediaServerDB(dbPath)
})
if err != nil {
return nil, err
}
return instance, nil
}
func NewMediaServerDB(dbPath string) (*MediaServerDB, error) {
db, err := sql.Open("sqlite", dbPath)
if err != nil {
return nil, err
}
// 创建媒体服务器表
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS media_servers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
type TEXT NOT NULL,
name TEXT NOT NULL,
ip TEXT NOT NULL,
port INTEGER NOT NULL,
username TEXT,
password TEXT,
secret TEXT,
is_default INTEGER NOT NULL DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return nil, err
}
return &MediaServerDB{db: db}, nil
}
// GetMediaServerByNameAndIP 根据名称和IP查询媒体服务器
func (m *MediaServerDB) GetMediaServerByNameAndIP(name, ip string) (*models.MediaServerResponse, error) {
var ms models.MediaServerResponse
err := m.db.QueryRow(`
SELECT id, name, type, ip, port, username, password, secret, is_default, created_at
FROM media_servers WHERE name = ? AND ip = ?
`, name, ip).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt)
if err != nil {
return nil, err
}
return &ms, nil
}
func (m *MediaServerDB) AddMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error {
_, err := m.db.Exec(`
INSERT INTO media_servers (name, type, ip, port, username, password, secret, is_default)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`, name, serverType, ip, port, username, password, secret, isDefault)
return err
}
// AddOrUpdateMediaServer 添加或更新媒体服务器(如果已存在则更新)
func (m *MediaServerDB) AddOrUpdateMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error {
// 检查是否已存在
existing, err := m.GetMediaServerByNameAndIP(name, ip)
if err == nil && existing != nil {
// 已存在,更新记录
_, err = m.db.Exec(`
UPDATE media_servers
SET type = ?, port = ?, username = ?, password = ?, secret = ?, is_default = ?
WHERE name = ? AND ip = ?
`, serverType, port, username, password, secret, isDefault, name, ip)
return err
}
// 不存在,插入新记录
return m.AddMediaServer(name, serverType, ip, port, username, password, secret, isDefault)
}
func (m *MediaServerDB) DeleteMediaServer(id int) error {
_, err := m.db.Exec("DELETE FROM media_servers WHERE id = ?", id)
return err
}
func (m *MediaServerDB) GetMediaServer(id int) (*models.MediaServerResponse, error) {
var ms models.MediaServerResponse
err := m.db.QueryRow(`
SELECT id, name, type, ip, port, username, password, secret, is_default, created_at
FROM media_servers WHERE id = ?
`, id).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt)
if err != nil {
return nil, err
}
return &ms, nil
}
func (m *MediaServerDB) ListMediaServers() ([]models.MediaServerResponse, error) {
rows, err := m.db.Query(`
SELECT id, name, type, ip, port, username, password, secret, is_default, created_at
FROM media_servers ORDER BY created_at DESC
`)
if err != nil {
return nil, err
}
defer rows.Close()
var servers []models.MediaServerResponse
for rows.Next() {
var ms models.MediaServerResponse
err := rows.Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt)
if err != nil {
return nil, err
}
servers = append(servers, ms)
}
return servers, nil
}
func (m *MediaServerDB) SetDefaultMediaServer(id int) error {
// 先将所有服务器设置为非默认
if _, err := m.db.Exec("UPDATE media_servers SET is_default = 0"); err != nil {
return err
}
// 将指定ID的服务器设置为默认
_, err := m.db.Exec("UPDATE media_servers SET is_default = 1 WHERE id = ?", id)
return err
}
func (m *MediaServerDB) Close() error {
return m.db.Close()
}
package db
import (
"database/sql"
"sync"
"github.com/ossrs/srs-sip/pkg/models"
_ "modernc.org/sqlite"
)
var (
instance *MediaServerDB
once sync.Once
)
type MediaServerDB struct {
models.MediaServerResponse
db *sql.DB
}
// GetInstance 返回 MediaServerDB 的单例实例
func GetInstance(dbPath string) (*MediaServerDB, error) {
var err error
once.Do(func() {
instance, err = NewMediaServerDB(dbPath)
})
if err != nil {
return nil, err
}
return instance, nil
}
func NewMediaServerDB(dbPath string) (*MediaServerDB, error) {
db, err := sql.Open("sqlite", dbPath)
if err != nil {
return nil, err
}
// 创建媒体服务器表
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS media_servers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
type TEXT NOT NULL,
name TEXT NOT NULL,
ip TEXT NOT NULL,
port INTEGER NOT NULL,
username TEXT,
password TEXT,
secret TEXT,
is_default INTEGER NOT NULL DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return nil, err
}
return &MediaServerDB{db: db}, nil
}
// GetMediaServerByNameAndIP 根据名称和IP查询媒体服务器
func (m *MediaServerDB) GetMediaServerByNameAndIP(name, ip string) (*models.MediaServerResponse, error) {
var ms models.MediaServerResponse
err := m.db.QueryRow(`
SELECT id, name, type, ip, port, username, password, secret, is_default, created_at
FROM media_servers WHERE name = ? AND ip = ?
`, name, ip).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt)
if err != nil {
return nil, err
}
return &ms, nil
}
func (m *MediaServerDB) AddMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error {
_, err := m.db.Exec(`
INSERT INTO media_servers (name, type, ip, port, username, password, secret, is_default)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`, name, serverType, ip, port, username, password, secret, isDefault)
return err
}
// AddOrUpdateMediaServer 添加或更新媒体服务器(如果已存在则更新)
func (m *MediaServerDB) AddOrUpdateMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error {
// 检查是否已存在
existing, err := m.GetMediaServerByNameAndIP(name, ip)
if err == nil && existing != nil {
// 已存在,更新记录
_, err = m.db.Exec(`
UPDATE media_servers
SET type = ?, port = ?, username = ?, password = ?, secret = ?, is_default = ?
WHERE name = ? AND ip = ?
`, serverType, port, username, password, secret, isDefault, name, ip)
return err
}
// 不存在,插入新记录
return m.AddMediaServer(name, serverType, ip, port, username, password, secret, isDefault)
}
func (m *MediaServerDB) DeleteMediaServer(id int) error {
_, err := m.db.Exec("DELETE FROM media_servers WHERE id = ?", id)
return err
}
func (m *MediaServerDB) GetMediaServer(id int) (*models.MediaServerResponse, error) {
var ms models.MediaServerResponse
err := m.db.QueryRow(`
SELECT id, name, type, ip, port, username, password, secret, is_default, created_at
FROM media_servers WHERE id = ?
`, id).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt)
if err != nil {
return nil, err
}
return &ms, nil
}
func (m *MediaServerDB) ListMediaServers() ([]models.MediaServerResponse, error) {
rows, err := m.db.Query(`
SELECT id, name, type, ip, port, username, password, secret, is_default, created_at
FROM media_servers ORDER BY created_at DESC
`)
if err != nil {
return nil, err
}
defer rows.Close()
var servers []models.MediaServerResponse
for rows.Next() {
var ms models.MediaServerResponse
err := rows.Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt)
if err != nil {
return nil, err
}
servers = append(servers, ms)
}
return servers, nil
}
func (m *MediaServerDB) SetDefaultMediaServer(id int) error {
// 先将所有服务器设置为非默认
if _, err := m.db.Exec("UPDATE media_servers SET is_default = 0"); err != nil {
return err
}
// 将指定ID的服务器设置为默认
_, err := m.db.Exec("UPDATE media_servers SET is_default = 1 WHERE id = ?", id)
return err
}
func (m *MediaServerDB) Close() error {
return m.db.Close()
}

View File

@ -105,3 +105,212 @@ func TestAddMediaServerDuplicates(t *testing.T) {
t.Log("Test confirmed: Old AddMediaServer method creates duplicates")
}
func TestGetMediaServer(t *testing.T) {
dbPath := "./test_get_media_server.db"
defer os.Remove(dbPath)
db, err := NewMediaServerDB(dbPath)
if err != nil {
t.Fatalf("Failed to create database: %v", err)
}
defer db.Close()
// 添加一个媒体服务器
err = db.AddMediaServer("TestServer", "ZLM", "192.168.1.200", 8080, "admin", "pass123", "secret123", 0)
if err != nil {
t.Fatalf("Failed to add media server: %v", err)
}
// 获取服务器列表以获得ID
servers, err := db.ListMediaServers()
if err != nil {
t.Fatalf("Failed to list media servers: %v", err)
}
if len(servers) == 0 {
t.Fatal("No servers found")
}
// 通过ID获取服务器
server, err := db.GetMediaServer(servers[0].ID)
if err != nil {
t.Fatalf("Failed to get media server: %v", err)
}
// 验证数据
if server.Name != "TestServer" {
t.Errorf("Expected name 'TestServer', got '%s'", server.Name)
}
if server.Type != "ZLM" {
t.Errorf("Expected type 'ZLM', got '%s'", server.Type)
}
if server.IP != "192.168.1.200" {
t.Errorf("Expected IP '192.168.1.200', got '%s'", server.IP)
}
if server.Port != 8080 {
t.Errorf("Expected port 8080, got %d", server.Port)
}
}
func TestGetMediaServerNotFound(t *testing.T) {
dbPath := "./test_get_not_found.db"
defer os.Remove(dbPath)
db, err := NewMediaServerDB(dbPath)
if err != nil {
t.Fatalf("Failed to create database: %v", err)
}
defer db.Close()
// 尝试获取不存在的服务器
_, err = db.GetMediaServer(999)
if err == nil {
t.Error("Expected error when getting non-existent server, got nil")
}
}
func TestDeleteMediaServer(t *testing.T) {
dbPath := "./test_delete_media_server.db"
defer os.Remove(dbPath)
db, err := NewMediaServerDB(dbPath)
if err != nil {
t.Fatalf("Failed to create database: %v", err)
}
defer db.Close()
// 添加两个服务器
err = db.AddMediaServer("Server1", "SRS", "192.168.1.1", 1985, "", "", "", 0)
if err != nil {
t.Fatalf("Failed to add server1: %v", err)
}
err = db.AddMediaServer("Server2", "ZLM", "192.168.1.2", 8080, "", "", "", 0)
if err != nil {
t.Fatalf("Failed to add server2: %v", err)
}
// 获取服务器列表
servers, err := db.ListMediaServers()
if err != nil {
t.Fatalf("Failed to list servers: %v", err)
}
if len(servers) != 2 {
t.Fatalf("Expected 2 servers, got %d", len(servers))
}
// 删除第一个服务器
err = db.DeleteMediaServer(servers[0].ID)
if err != nil {
t.Fatalf("Failed to delete server: %v", err)
}
// 验证只剩一个服务器
servers, err = db.ListMediaServers()
if err != nil {
t.Fatalf("Failed to list servers after delete: %v", err)
}
if len(servers) != 1 {
t.Fatalf("Expected 1 server after delete, got %d", len(servers))
}
}
func TestSetDefaultMediaServer(t *testing.T) {
dbPath := "./test_set_default.db"
defer os.Remove(dbPath)
db, err := NewMediaServerDB(dbPath)
if err != nil {
t.Fatalf("Failed to create database: %v", err)
}
defer db.Close()
// 添加三个服务器
err = db.AddMediaServer("Server1", "SRS", "192.168.1.1", 1985, "", "", "", 1)
if err != nil {
t.Fatalf("Failed to add server1: %v", err)
}
err = db.AddMediaServer("Server2", "ZLM", "192.168.1.2", 8080, "", "", "", 0)
if err != nil {
t.Fatalf("Failed to add server2: %v", err)
}
err = db.AddMediaServer("Server3", "SRS", "192.168.1.3", 1985, "", "", "", 0)
if err != nil {
t.Fatalf("Failed to add server3: %v", err)
}
// 获取服务器列表
servers, err := db.ListMediaServers()
if err != nil {
t.Fatalf("Failed to list servers: %v", err)
}
// 找到 Server2 的 ID
var server2ID int
for _, s := range servers {
if s.Name == "Server2" {
server2ID = s.ID
break
}
}
// 设置 Server2 为默认
err = db.SetDefaultMediaServer(server2ID)
if err != nil {
t.Fatalf("Failed to set default server: %v", err)
}
// 验证只有 Server2 是默认的
servers, err = db.ListMediaServers()
if err != nil {
t.Fatalf("Failed to list servers: %v", err)
}
defaultCount := 0
for _, s := range servers {
if s.IsDefault == 1 {
defaultCount++
if s.Name != "Server2" {
t.Errorf("Expected Server2 to be default, got %s", s.Name)
}
}
}
if defaultCount != 1 {
t.Errorf("Expected exactly 1 default server, got %d", defaultCount)
}
}
func TestGetMediaServerByNameAndIP(t *testing.T) {
dbPath := "./test_get_by_name_ip.db"
defer os.Remove(dbPath)
db, err := NewMediaServerDB(dbPath)
if err != nil {
t.Fatalf("Failed to create database: %v", err)
}
defer db.Close()
// 添加服务器
err = db.AddMediaServer("MyServer", "SRS", "10.0.0.1", 1985, "user", "pass", "secret", 0)
if err != nil {
t.Fatalf("Failed to add server: %v", err)
}
// 通过名称和IP查询
server, err := db.GetMediaServerByNameAndIP("MyServer", "10.0.0.1")
if err != nil {
t.Fatalf("Failed to get server by name and IP: %v", err)
}
if server.Name != "MyServer" || server.IP != "10.0.0.1" {
t.Errorf("Server data mismatch: %+v", server)
}
// 查询不存在的组合
_, err = db.GetMediaServerByNameAndIP("MyServer", "10.0.0.2")
if err == nil {
t.Error("Expected error for non-existent name/IP combination, got nil")
}
}

298
pkg/media/media_test.go Normal file
View File

@ -0,0 +1,298 @@
package media
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestApiRequest_Success(t *testing.T) {
// Create a test server that returns a successful response
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{
"code": 0,
"data": map[string]string{
"message": "success",
},
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
ctx := context.Background()
req := map[string]string{"test": "data"}
var res map[string]interface{}
err := apiRequest(ctx, server.URL, req, &res)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if res["code"].(float64) != 0 {
t.Errorf("Expected code 0, got %v", res["code"])
}
}
func TestApiRequest_GetMethod(t *testing.T) {
// Create a test server that checks the HTTP method
methodReceived := ""
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
methodReceived = r.Method
response := map[string]interface{}{
"code": 0,
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
ctx := context.Background()
var res map[string]interface{}
// When req is nil, should use GET method
err := apiRequest(ctx, server.URL, nil, &res)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if methodReceived != "GET" {
t.Errorf("Expected GET method, got %s", methodReceived)
}
}
func TestApiRequest_PostMethod(t *testing.T) {
// Create a test server that checks the HTTP method
methodReceived := ""
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
methodReceived = r.Method
response := map[string]interface{}{
"code": 0,
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
ctx := context.Background()
req := map[string]string{"test": "data"}
var res map[string]interface{}
// When req is not nil, should use POST method
err := apiRequest(ctx, server.URL, req, &res)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if methodReceived != "POST" {
t.Errorf("Expected POST method, got %s", methodReceived)
}
}
func TestApiRequest_ServerError(t *testing.T) {
// Create a test server that returns an error status code
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
ctx := context.Background()
var res map[string]interface{}
err := apiRequest(ctx, server.URL, nil, &res)
if err == nil {
t.Error("Expected error for server error status code")
}
}
func TestApiRequest_NonZeroCode(t *testing.T) {
// Create a test server that returns a non-zero error code
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{
"code": 100,
"message": "error message",
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
ctx := context.Background()
var res map[string]interface{}
err := apiRequest(ctx, server.URL, nil, &res)
if err == nil {
t.Error("Expected error for non-zero code")
}
}
func TestApiRequest_InvalidJSON(t *testing.T) {
// Create a test server that returns invalid JSON
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("invalid json"))
}))
defer server.Close()
ctx := context.Background()
var res map[string]interface{}
err := apiRequest(ctx, server.URL, nil, &res)
if err == nil {
t.Error("Expected error for invalid JSON")
}
}
func TestApiRequest_ContextCancellation(t *testing.T) {
// Create a test server that delays the response
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
response := map[string]interface{}{
"code": 0,
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
// Create a context that is already cancelled
ctx, cancel := context.WithCancel(context.Background())
cancel()
var res map[string]interface{}
err := apiRequest(ctx, server.URL, nil, &res)
if err == nil {
t.Error("Expected error for cancelled context")
}
}
func TestApiRequest_InvalidURL(t *testing.T) {
ctx := context.Background()
var res map[string]interface{}
// Test with invalid URL
err := apiRequest(ctx, "://invalid-url", nil, &res)
if err == nil {
t.Error("Expected error for invalid URL")
}
}
func TestApiRequest_Timeout(t *testing.T) {
// Create a test server that never responds
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(15 * time.Second) // Longer than the 10 second timeout
}))
defer server.Close()
ctx := context.Background()
var res map[string]interface{}
err := apiRequest(ctx, server.URL, nil, &res)
if err == nil {
t.Error("Expected timeout error")
}
}
func TestApiRequest_ComplexResponse(t *testing.T) {
// Create a test server that returns a complex response
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{
"code": 0,
"data": map[string]interface{}{
"id": "12345",
"status": "active",
"items": []string{
"item1",
"item2",
"item3",
},
},
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
ctx := context.Background()
var res map[string]interface{}
err := apiRequest(ctx, server.URL, nil, &res)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
// Check that the response was properly unmarshaled
if res["code"].(float64) != 0 {
t.Errorf("Expected code 0, got %v", res["code"])
}
data, ok := res["data"].(map[string]interface{})
if !ok {
t.Error("Expected data to be a map")
} else {
if data["id"].(string) != "12345" {
t.Errorf("Expected id '12345', got %v", data["id"])
}
}
}
func TestApiRequest_EmptyResponse(t *testing.T) {
// Create a test server that returns minimal response
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{
"code": 0,
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
ctx := context.Background()
var res map[string]interface{}
err := apiRequest(ctx, server.URL, nil, &res)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if res["code"].(float64) != 0 {
t.Errorf("Expected code 0, got %v", res["code"])
}
}
func TestApiRequest_WithRequestBody(t *testing.T) {
// Create a test server that echoes back the request
var receivedBody map[string]interface{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&receivedBody)
response := map[string]interface{}{
"code": 0,
"echo": receivedBody,
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
ctx := context.Background()
req := map[string]interface{}{
"id": "test-id",
"ssrc": "test-ssrc",
}
var res map[string]interface{}
err := apiRequest(ctx, server.URL, req, &res)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
// Check that the server received the request body
if receivedBody["id"].(string) != "test-id" {
t.Errorf("Expected id 'test-id', got %v", receivedBody["id"])
}
}

View File

@ -1,59 +1,59 @@
package media
import (
"context"
"github.com/ossrs/go-oryx-lib/errors"
)
type Zlm struct {
Ctx context.Context
Schema string // The schema of ZLM, eg: http
Addr string // The address of ZLM, eg: localhost:8085
Secret string // The secret of ZLM, eg: ZLMediaKit_secret
}
// /index/api/openRtpServer
// secret={{ZLMediaKit_secret}}&port=0&enable_tcp=1&stream_id=test2
func (z *Zlm) Publish(id, ssrc string) (int, error) {
res := struct {
Code int `json:"code"`
Port int `json:"port"`
}{}
if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/openRtpServer?secret="+z.Secret+"&port=0&enable_tcp=1&stream_id="+id+"&ssrc="+ssrc, nil, &res); err != nil {
return 0, errors.Wrapf(err, "gb/v1/publish")
}
return res.Port, nil
}
// /index/api/closeRtpServer
func (z *Zlm) Unpublish(id string) error {
res := struct {
Code int `json:"code"`
}{}
if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/closeRtpServer?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil {
return errors.Wrapf(err, "gb/v1/publish")
}
return nil
}
// /index/api/getMediaList
func (z *Zlm) GetStreamStatus(id string) (bool, error) {
res := struct {
Code int `json:"code"`
}{}
if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/getMediaList?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil {
return false, errors.Wrapf(err, "gb/v1/publish")
}
return res.Code == 0, nil
}
func (z *Zlm) GetAddr() string {
return z.Addr
}
func (z *Zlm) GetWebRTCAddr(id string) string {
return "http://" + z.Addr + "/index/api/webrtc?app=rtp&stream=" + id + "&type=play"
}
package media
import (
"context"
"github.com/ossrs/go-oryx-lib/errors"
)
type Zlm struct {
Ctx context.Context
Schema string // The schema of ZLM, eg: http
Addr string // The address of ZLM, eg: localhost:8085
Secret string // The secret of ZLM, eg: ZLMediaKit_secret
}
// /index/api/openRtpServer
// secret={{ZLMediaKit_secret}}&port=0&enable_tcp=1&stream_id=test2
func (z *Zlm) Publish(id, ssrc string) (int, error) {
res := struct {
Code int `json:"code"`
Port int `json:"port"`
}{}
if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/openRtpServer?secret="+z.Secret+"&port=0&enable_tcp=1&stream_id="+id+"&ssrc="+ssrc, nil, &res); err != nil {
return 0, errors.Wrapf(err, "gb/v1/publish")
}
return res.Port, nil
}
// /index/api/closeRtpServer
func (z *Zlm) Unpublish(id string) error {
res := struct {
Code int `json:"code"`
}{}
if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/closeRtpServer?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil {
return errors.Wrapf(err, "gb/v1/publish")
}
return nil
}
// /index/api/getMediaList
func (z *Zlm) GetStreamStatus(id string) (bool, error) {
res := struct {
Code int `json:"code"`
}{}
if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/getMediaList?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil {
return false, errors.Wrapf(err, "gb/v1/publish")
}
return res.Code == 0, nil
}
func (z *Zlm) GetAddr() string {
return z.Addr
}
func (z *Zlm) GetWebRTCAddr(id string) string {
return "http://" + z.Addr + "/index/api/webrtc?app=rtp&stream=" + id + "&type=play"
}

View File

@ -1,106 +1,106 @@
package models
import "encoding/xml"
type Record struct {
DeviceID string `xml:"DeviceID" json:"device_id"`
Name string `xml:"Name" json:"name"`
FilePath string `xml:"FilePath" json:"file_path"`
Address string `xml:"Address" json:"address"`
StartTime string `xml:"StartTime" json:"start_time"`
EndTime string `xml:"EndTime" json:"end_time"`
Secrecy int `xml:"Secrecy" json:"secrecy"`
Type string `xml:"Type" json:"type"`
}
// Example XML structure for channel info:
//
// <Item>
// <DeviceID>34020000001320000002</DeviceID>
// <Name>209</Name>
// <Manufacturer>UNIVIEW</Manufacturer>
// <Model>HIC6622-IR@X33-VF</Model>
// <Owner>IPC-B2202.7.11.230222</Owner>
// <CivilCode>CivilCode</CivilCode>
// <Address>Address</Address>
// <Parental>1</Parental>
// <ParentID>75015310072008100002</ParentID>
// <SafetyWay>0</SafetyWay>
// <RegisterWay>1</RegisterWay>
// <Secrecy>0</Secrecy>
// <Status>ON</Status>
// <Longitude>0.0000000</Longitude>
// <Latitude>0.0000000</Latitude>
// <Info>
// <PTZType>1</PTZType>
// <Resolution>6/4/2</Resolution>
// <DownloadSpeed>0</DownloadSpeed>
// </Info>
// </Item>
type ChannelInfo struct {
DeviceID string `json:"device_id"`
ParentID string `json:"parent_id"`
Name string `json:"name"`
Manufacturer string `json:"manufacturer"`
Model string `json:"model"`
Owner string `json:"owner"`
CivilCode string `json:"civil_code"`
Address string `json:"address"`
Port int `json:"port"`
Parental int `json:"parental"`
SafetyWay int `json:"safety_way"`
RegisterWay int `json:"register_way"`
Secrecy int `json:"secrecy"`
IPAddress string `json:"ip_address"`
Status ChannelStatus `json:"status"`
Longitude float64 `json:"longitude"`
Latitude float64 `json:"latitude"`
Info struct {
PTZType int `json:"ptz_type"`
Resolution string `json:"resolution"`
DownloadSpeed string `json:"download_speed"` // Speed levels: 1/2/4/8
} `json:"info"`
// Custom fields
Ssrc string `json:"ssrc"`
}
type ChannelStatus string
// BasicParam
// <! -- 基本参数配置(可选)-->
// <elementname="BasicParam"minOccurs="0">
// <complexType>
// <sequence>
// <! -- 设备名称(可选)-->
// <elementname="Name"type="string" minOccurs="0"/>
// <! -- 注册过期时间(可选)-->
// <elementname="Expiration"type="integer" minOccurs="0"/>
// <! -- 心跳间隔时间(可选)-->
// <elementname="HeartBeatInterval"type="integer" minOccurs="0"/>
// <! -- 心跳超时次数(可选)-->
// <elementname="HeartBeatCount"type="integer" minOccurs="0"/>
// </sequence>
// </complexType>
type BasicParam struct {
Name string `xml:"Name"`
Expiration int `xml:"Expiration"`
HeartBeatInterval int `xml:"HeartBeatInterval"`
HeartBeatCount int `xml:"HeartBeatCount"`
}
type XmlMessageInfo struct {
XMLName xml.Name
CmdType string
SN int
DeviceID string
DeviceName string
Manufacturer string
Model string
Channel string
DeviceList []ChannelInfo `xml:"DeviceList>Item"`
RecordList []*Record `xml:"RecordList>Item"`
BasicParam BasicParam `xml:"BasicParam"`
SumNum int
}
package models
import "encoding/xml"
type Record struct {
DeviceID string `xml:"DeviceID" json:"device_id"`
Name string `xml:"Name" json:"name"`
FilePath string `xml:"FilePath" json:"file_path"`
Address string `xml:"Address" json:"address"`
StartTime string `xml:"StartTime" json:"start_time"`
EndTime string `xml:"EndTime" json:"end_time"`
Secrecy int `xml:"Secrecy" json:"secrecy"`
Type string `xml:"Type" json:"type"`
}
// Example XML structure for channel info:
//
// <Item>
// <DeviceID>34020000001320000002</DeviceID>
// <Name>209</Name>
// <Manufacturer>UNIVIEW</Manufacturer>
// <Model>HIC6622-IR@X33-VF</Model>
// <Owner>IPC-B2202.7.11.230222</Owner>
// <CivilCode>CivilCode</CivilCode>
// <Address>Address</Address>
// <Parental>1</Parental>
// <ParentID>75015310072008100002</ParentID>
// <SafetyWay>0</SafetyWay>
// <RegisterWay>1</RegisterWay>
// <Secrecy>0</Secrecy>
// <Status>ON</Status>
// <Longitude>0.0000000</Longitude>
// <Latitude>0.0000000</Latitude>
// <Info>
// <PTZType>1</PTZType>
// <Resolution>6/4/2</Resolution>
// <DownloadSpeed>0</DownloadSpeed>
// </Info>
// </Item>
type ChannelInfo struct {
DeviceID string `json:"device_id"`
ParentID string `json:"parent_id"`
Name string `json:"name"`
Manufacturer string `json:"manufacturer"`
Model string `json:"model"`
Owner string `json:"owner"`
CivilCode string `json:"civil_code"`
Address string `json:"address"`
Port int `json:"port"`
Parental int `json:"parental"`
SafetyWay int `json:"safety_way"`
RegisterWay int `json:"register_way"`
Secrecy int `json:"secrecy"`
IPAddress string `json:"ip_address"`
Status ChannelStatus `json:"status"`
Longitude float64 `json:"longitude"`
Latitude float64 `json:"latitude"`
Info struct {
PTZType int `json:"ptz_type"`
Resolution string `json:"resolution"`
DownloadSpeed string `json:"download_speed"` // Speed levels: 1/2/4/8
} `json:"info"`
// Custom fields
Ssrc string `json:"ssrc"`
}
type ChannelStatus string
// BasicParam
// <! -- 基本参数配置(可选)-->
// <elementname="BasicParam"minOccurs="0">
// <complexType>
// <sequence>
// <! -- 设备名称(可选)-->
// <elementname="Name"type="string" minOccurs="0"/>
// <! -- 注册过期时间(可选)-->
// <elementname="Expiration"type="integer" minOccurs="0"/>
// <! -- 心跳间隔时间(可选)-->
// <elementname="HeartBeatInterval"type="integer" minOccurs="0"/>
// <! -- 心跳超时次数(可选)-->
// <elementname="HeartBeatCount"type="integer" minOccurs="0"/>
// </sequence>
// </complexType>
type BasicParam struct {
Name string `xml:"Name"`
Expiration int `xml:"Expiration"`
HeartBeatInterval int `xml:"HeartBeatInterval"`
HeartBeatCount int `xml:"HeartBeatCount"`
}
type XmlMessageInfo struct {
XMLName xml.Name
CmdType string
SN int
DeviceID string
DeviceName string
Manufacturer string
Model string
Channel string
DeviceList []ChannelInfo `xml:"DeviceList>Item"`
RecordList []*Record `xml:"RecordList>Item"`
BasicParam BasicParam `xml:"BasicParam"`
SumNum int
}

View File

@ -1,80 +1,80 @@
package models
type BaseRequest struct {
DeviceID string `json:"device_id"`
ChannelID string `json:"channel_id"`
}
type InviteRequest struct {
BaseRequest
MediaServerId int `json:"media_server_id"`
PlayType int `json:"play_type"` // 0: live, 1: playback, 2: download
SubStream int `json:"sub_stream"`
StartTime int64 `json:"start_time"`
EndTime int64 `json:"end_time"`
}
type InviteResponse struct {
ChannelID string `json:"channel_id"`
URL string `json:"url"`
}
type SessionRequest struct {
BaseRequest
URL string `json:"url"`
}
type ByeRequest struct {
SessionRequest
}
type PauseRequest struct {
SessionRequest
}
type ResumeRequest struct {
SessionRequest
}
type SpeedRequest struct {
SessionRequest
Speed float32 `json:"speed"`
}
type PTZControlRequest struct {
BaseRequest
PTZ string `json:"ptz"`
Speed string `json:"speed"`
}
type QueryRecordRequest struct {
BaseRequest
StartTime int64 `json:"start_time"`
EndTime int64 `json:"end_time"`
}
type MediaServer struct {
Name string `json:"name"`
Type string `json:"type"`
IP string `json:"ip"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"password"`
Secret string `json:"secret"`
IsDefault int `json:"is_default"`
}
type MediaServerRequest struct {
MediaServer
}
type MediaServerResponse struct {
MediaServer
ID int `json:"id"`
CreatedAt string `json:"created_at"`
}
type CommonResponse struct {
Code int `json:"code"`
Data interface{} `json:"data"`
}
package models
type BaseRequest struct {
DeviceID string `json:"device_id"`
ChannelID string `json:"channel_id"`
}
type InviteRequest struct {
BaseRequest
MediaServerId int `json:"media_server_id"`
PlayType int `json:"play_type"` // 0: live, 1: playback, 2: download
SubStream int `json:"sub_stream"`
StartTime int64 `json:"start_time"`
EndTime int64 `json:"end_time"`
}
type InviteResponse struct {
ChannelID string `json:"channel_id"`
URL string `json:"url"`
}
type SessionRequest struct {
BaseRequest
URL string `json:"url"`
}
type ByeRequest struct {
SessionRequest
}
type PauseRequest struct {
SessionRequest
}
type ResumeRequest struct {
SessionRequest
}
type SpeedRequest struct {
SessionRequest
Speed float32 `json:"speed"`
}
type PTZControlRequest struct {
BaseRequest
PTZ string `json:"ptz"`
Speed string `json:"speed"`
}
type QueryRecordRequest struct {
BaseRequest
StartTime int64 `json:"start_time"`
EndTime int64 `json:"end_time"`
}
type MediaServer struct {
Name string `json:"name"`
Type string `json:"type"`
IP string `json:"ip"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"password"`
Secret string `json:"secret"`
IsDefault int `json:"is_default"`
}
type MediaServerRequest struct {
MediaServer
}
type MediaServerResponse struct {
MediaServer
ID int `json:"id"`
CreatedAt string `json:"created_at"`
}
type CommonResponse struct {
Code int `json:"code"`
Data interface{} `json:"data"`
}

337
pkg/models/types_test.go Normal file
View File

@ -0,0 +1,337 @@
package models
import (
"encoding/json"
"testing"
)
func TestBaseRequest(t *testing.T) {
req := BaseRequest{
DeviceID: "34020000001320000001",
ChannelID: "34020000001320000002",
}
if req.DeviceID != "34020000001320000001" {
t.Errorf("Expected DeviceID '34020000001320000001', got '%s'", req.DeviceID)
}
if req.ChannelID != "34020000001320000002" {
t.Errorf("Expected ChannelID '34020000001320000002', got '%s'", req.ChannelID)
}
}
func TestInviteRequest(t *testing.T) {
req := InviteRequest{
BaseRequest: BaseRequest{
DeviceID: "device123",
ChannelID: "channel123",
},
MediaServerId: 1,
PlayType: 0,
SubStream: 0,
StartTime: 1234567890,
EndTime: 1234567900,
}
if req.DeviceID != "device123" {
t.Errorf("Expected DeviceID 'device123', got '%s'", req.DeviceID)
}
if req.MediaServerId != 1 {
t.Errorf("Expected MediaServerId 1, got %d", req.MediaServerId)
}
if req.PlayType != 0 {
t.Errorf("Expected PlayType 0, got %d", req.PlayType)
}
}
func TestInviteRequestJSON(t *testing.T) {
jsonStr := `{
"device_id": "device123",
"channel_id": "channel123",
"media_server_id": 1,
"play_type": 1,
"sub_stream": 0,
"start_time": 1234567890,
"end_time": 1234567900
}`
var req InviteRequest
err := json.Unmarshal([]byte(jsonStr), &req)
if err != nil {
t.Fatalf("Failed to unmarshal JSON: %v", err)
}
if req.DeviceID != "device123" {
t.Errorf("Expected DeviceID 'device123', got '%s'", req.DeviceID)
}
if req.PlayType != 1 {
t.Errorf("Expected PlayType 1, got %d", req.PlayType)
}
}
func TestInviteResponse(t *testing.T) {
resp := InviteResponse{
ChannelID: "channel123",
URL: "webrtc://example.com/live/stream",
}
if resp.ChannelID != "channel123" {
t.Errorf("Expected ChannelID 'channel123', got '%s'", resp.ChannelID)
}
if resp.URL != "webrtc://example.com/live/stream" {
t.Errorf("Expected URL 'webrtc://example.com/live/stream', got '%s'", resp.URL)
}
}
func TestPTZControlRequest(t *testing.T) {
req := PTZControlRequest{
BaseRequest: BaseRequest{
DeviceID: "device123",
ChannelID: "channel123",
},
PTZ: "up",
Speed: "5",
}
if req.PTZ != "up" {
t.Errorf("Expected PTZ 'up', got '%s'", req.PTZ)
}
if req.Speed != "5" {
t.Errorf("Expected Speed '5', got '%s'", req.Speed)
}
}
func TestQueryRecordRequest(t *testing.T) {
req := QueryRecordRequest{
BaseRequest: BaseRequest{
DeviceID: "device123",
ChannelID: "channel123",
},
StartTime: 1234567890,
EndTime: 1234567900,
}
if req.StartTime != 1234567890 {
t.Errorf("Expected StartTime 1234567890, got %d", req.StartTime)
}
if req.EndTime != 1234567900 {
t.Errorf("Expected EndTime 1234567900, got %d", req.EndTime)
}
}
func TestMediaServer(t *testing.T) {
ms := MediaServer{
Name: "SRS Server",
Type: "SRS",
IP: "192.168.1.100",
Port: 1985,
Username: "admin",
Password: "password",
Secret: "secret",
IsDefault: 1,
}
if ms.Name != "SRS Server" {
t.Errorf("Expected Name 'SRS Server', got '%s'", ms.Name)
}
if ms.Type != "SRS" {
t.Errorf("Expected Type 'SRS', got '%s'", ms.Type)
}
if ms.Port != 1985 {
t.Errorf("Expected Port 1985, got %d", ms.Port)
}
if ms.IsDefault != 1 {
t.Errorf("Expected IsDefault 1, got %d", ms.IsDefault)
}
}
func TestMediaServerResponse(t *testing.T) {
resp := MediaServerResponse{
MediaServer: MediaServer{
Name: "Test Server",
Type: "ZLM",
IP: "10.0.0.1",
Port: 8080,
},
ID: 1,
CreatedAt: "2024-01-01 12:00:00",
}
if resp.ID != 1 {
t.Errorf("Expected ID 1, got %d", resp.ID)
}
if resp.CreatedAt != "2024-01-01 12:00:00" {
t.Errorf("Expected CreatedAt '2024-01-01 12:00:00', got '%s'", resp.CreatedAt)
}
}
func TestCommonResponse(t *testing.T) {
resp := CommonResponse{
Code: 0,
Data: map[string]string{"key": "value"},
}
if resp.Code != 0 {
t.Errorf("Expected Code 0, got %d", resp.Code)
}
// 测试 JSON 序列化
jsonData, err := json.Marshal(resp)
if err != nil {
t.Fatalf("Failed to marshal JSON: %v", err)
}
var decoded CommonResponse
err = json.Unmarshal(jsonData, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal JSON: %v", err)
}
if decoded.Code != 0 {
t.Errorf("Expected decoded Code 0, got %d", decoded.Code)
}
}
func TestSessionRequest(t *testing.T) {
req := SessionRequest{
BaseRequest: BaseRequest{
DeviceID: "device123",
ChannelID: "channel123",
},
URL: "webrtc://example.com/live/stream",
}
if req.URL != "webrtc://example.com/live/stream" {
t.Errorf("Expected URL 'webrtc://example.com/live/stream', got '%s'", req.URL)
}
}
func TestByeRequest(t *testing.T) {
req := ByeRequest{
SessionRequest: SessionRequest{
BaseRequest: BaseRequest{
DeviceID: "device123",
ChannelID: "channel123",
},
URL: "webrtc://example.com/live/stream",
},
}
if req.DeviceID != "device123" {
t.Errorf("Expected DeviceID 'device123', got '%s'", req.DeviceID)
}
}
func TestPauseRequest(t *testing.T) {
req := PauseRequest{
SessionRequest: SessionRequest{
BaseRequest: BaseRequest{
DeviceID: "device123",
ChannelID: "channel123",
},
URL: "webrtc://example.com/live/stream",
},
}
if req.URL != "webrtc://example.com/live/stream" {
t.Errorf("Expected URL 'webrtc://example.com/live/stream', got '%s'", req.URL)
}
}
func TestResumeRequest(t *testing.T) {
req := ResumeRequest{
SessionRequest: SessionRequest{
BaseRequest: BaseRequest{
DeviceID: "device123",
ChannelID: "channel123",
},
URL: "webrtc://example.com/live/stream",
},
}
if req.ChannelID != "channel123" {
t.Errorf("Expected ChannelID 'channel123', got '%s'", req.ChannelID)
}
}
func TestSpeedRequest(t *testing.T) {
req := SpeedRequest{
SessionRequest: SessionRequest{
BaseRequest: BaseRequest{
DeviceID: "device123",
ChannelID: "channel123",
},
URL: "webrtc://example.com/live/stream",
},
Speed: 2.0,
}
if req.Speed != 2.0 {
t.Errorf("Expected Speed 2.0, got %f", req.Speed)
}
}
func TestMediaServerRequestJSON(t *testing.T) {
jsonStr := `{
"name": "Test Server",
"type": "SRS",
"ip": "192.168.1.100",
"port": 1985,
"username": "admin",
"password": "pass123",
"secret": "secret123",
"is_default": 1
}`
var req MediaServerRequest
err := json.Unmarshal([]byte(jsonStr), &req)
if err != nil {
t.Fatalf("Failed to unmarshal JSON: %v", err)
}
if req.Name != "Test Server" {
t.Errorf("Expected Name 'Test Server', got '%s'", req.Name)
}
if req.Type != "SRS" {
t.Errorf("Expected Type 'SRS', got '%s'", req.Type)
}
if req.Port != 1985 {
t.Errorf("Expected Port 1985, got %d", req.Port)
}
}
func TestCommonResponseWithDifferentDataTypes(t *testing.T) {
tests := []struct {
name string
data interface{}
}{
{"String data", "test string"},
{"Integer data", 123},
{"Map data", map[string]interface{}{"key": "value"}},
{"Array data", []string{"item1", "item2"}},
{"Nil data", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp := CommonResponse{
Code: 0,
Data: tt.data,
}
jsonData, err := json.Marshal(resp)
if err != nil {
t.Fatalf("Failed to marshal JSON: %v", err)
}
var decoded CommonResponse
err = json.Unmarshal(jsonData, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal JSON: %v", err)
}
if decoded.Code != 0 {
t.Errorf("Expected Code 0, got %d", decoded.Code)
}
})
}
}

View File

@ -1,92 +1,92 @@
package service
import (
"crypto/md5"
"crypto/rand"
"encoding/hex"
"fmt"
"strings"
)
// AuthInfo 存储解析后的认证信息
type AuthInfo struct {
Username string
Realm string
Nonce string
URI string
Response string
Algorithm string
Method string
}
// GenerateNonce 生成随机 nonce 字符串
func GenerateNonce() string {
b := make([]byte, 16)
rand.Read(b)
return fmt.Sprintf("%x", b)
}
// ParseAuthorization 解析 SIP Authorization 头
// Authorization: Digest username="34020000001320000001",realm="3402000000",
// nonce="44010b73623249f6916a6acf7c316b8e",uri="sip:34020000002000000001@3402000000",
// response="e4ca3fdc5869fa1c544ea7af60014444",algorithm=MD5
func ParseAuthorization(auth string) *AuthInfo {
auth = strings.TrimPrefix(auth, "Digest ")
parts := strings.Split(auth, ",")
result := &AuthInfo{}
for _, part := range parts {
part = strings.TrimSpace(part)
if !strings.Contains(part, "=") {
continue
}
kv := strings.SplitN(part, "=", 2)
key := strings.TrimSpace(kv[0])
value := strings.Trim(strings.TrimSpace(kv[1]), "\"")
switch key {
case "username":
result.Username = value
case "realm":
result.Realm = value
case "nonce":
result.Nonce = value
case "uri":
result.URI = value
case "response":
result.Response = value
case "algorithm":
result.Algorithm = value
}
}
return result
}
// ValidateAuth 验证 SIP 认证信息
func ValidateAuth(authInfo *AuthInfo, password string) bool {
if authInfo == nil {
return false
}
// 默认方法为 REGISTER
method := "REGISTER"
if authInfo.Method != "" {
method = authInfo.Method
}
// 计算 MD5 哈希
ha1 := md5Hex(authInfo.Username + ":" + authInfo.Realm + ":" + password)
ha2 := md5Hex(method + ":" + authInfo.URI)
correctResponse := md5Hex(ha1 + ":" + authInfo.Nonce + ":" + ha2)
return authInfo.Response == correctResponse
}
// md5Hex 计算字符串的 MD5 哈希值并返回十六进制字符串
func md5Hex(s string) string {
hash := md5.New()
hash.Write([]byte(s))
return hex.EncodeToString(hash.Sum(nil))
}
package service
import (
"crypto/md5"
"crypto/rand"
"encoding/hex"
"fmt"
"strings"
)
// AuthInfo 存储解析后的认证信息
type AuthInfo struct {
Username string
Realm string
Nonce string
URI string
Response string
Algorithm string
Method string
}
// GenerateNonce 生成随机 nonce 字符串
func GenerateNonce() string {
b := make([]byte, 16)
rand.Read(b)
return fmt.Sprintf("%x", b)
}
// ParseAuthorization 解析 SIP Authorization 头
// Authorization: Digest username="34020000001320000001",realm="3402000000",
// nonce="44010b73623249f6916a6acf7c316b8e",uri="sip:34020000002000000001@3402000000",
// response="e4ca3fdc5869fa1c544ea7af60014444",algorithm=MD5
func ParseAuthorization(auth string) *AuthInfo {
auth = strings.TrimPrefix(auth, "Digest ")
parts := strings.Split(auth, ",")
result := &AuthInfo{}
for _, part := range parts {
part = strings.TrimSpace(part)
if !strings.Contains(part, "=") {
continue
}
kv := strings.SplitN(part, "=", 2)
key := strings.TrimSpace(kv[0])
value := strings.Trim(strings.TrimSpace(kv[1]), "\"")
switch key {
case "username":
result.Username = value
case "realm":
result.Realm = value
case "nonce":
result.Nonce = value
case "uri":
result.URI = value
case "response":
result.Response = value
case "algorithm":
result.Algorithm = value
}
}
return result
}
// ValidateAuth 验证 SIP 认证信息
func ValidateAuth(authInfo *AuthInfo, password string) bool {
if authInfo == nil {
return false
}
// 默认方法为 REGISTER
method := "REGISTER"
if authInfo.Method != "" {
method = authInfo.Method
}
// 计算 MD5 哈希
ha1 := md5Hex(authInfo.Username + ":" + authInfo.Realm + ":" + password)
ha2 := md5Hex(method + ":" + authInfo.URI)
correctResponse := md5Hex(ha1 + ":" + authInfo.Nonce + ":" + ha2)
return authInfo.Response == correctResponse
}
// md5Hex 计算字符串的 MD5 哈希值并返回十六进制字符串
func md5Hex(s string) string {
hash := md5.New()
hash.Write([]byte(s))
return hex.EncodeToString(hash.Sum(nil))
}

345
pkg/service/auth_test.go Normal file
View File

@ -0,0 +1,345 @@
package service
import (
"strings"
"testing"
)
func TestGenerateNonce(t *testing.T) {
// 生成多个 nonce 并验证
nonces := make(map[string]bool)
iterations := 100
for i := 0; i < iterations; i++ {
nonce := GenerateNonce()
// 验证长度16字节的十六进制表示应该是32个字符
if len(nonce) != 32 {
t.Errorf("Expected nonce length 32, got %d", len(nonce))
}
// 验证是否为十六进制字符串
for _, c := range nonce {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
t.Errorf("Nonce contains non-hex character: %c", c)
}
}
nonces[nonce] = true
}
// 验证唯一性(应该生成不同的 nonce
if len(nonces) < 95 { // 允许极小概率的重复
t.Errorf("Expected at least 95 unique nonces out of %d, got %d", iterations, len(nonces))
}
}
func TestParseAuthorization(t *testing.T) {
tests := []struct {
name string
auth string
expected *AuthInfo
}{
{
name: "Complete authorization header",
auth: `Digest username="34020000001320000001",realm="3402000000",nonce="44010b73623249f6916a6acf7c316b8e",uri="sip:34020000002000000001@3402000000",response="e4ca3fdc5869fa1c544ea7af60014444",algorithm=MD5`,
expected: &AuthInfo{
Username: "34020000001320000001",
Realm: "3402000000",
Nonce: "44010b73623249f6916a6acf7c316b8e",
URI: "sip:34020000002000000001@3402000000",
Response: "e4ca3fdc5869fa1c544ea7af60014444",
Algorithm: "MD5",
},
},
{
name: "Authorization with spaces",
auth: `Digest username = "user123" , realm = "realm123" , nonce = "nonce123" , uri = "sip:test@example.com" , response = "resp123"`,
expected: &AuthInfo{
Username: "user123",
Realm: "realm123",
Nonce: "nonce123",
URI: "sip:test@example.com",
Response: "resp123",
},
},
{
name: "Partial authorization",
auth: `Digest username="testuser",realm="testrealm"`,
expected: &AuthInfo{
Username: "testuser",
Realm: "testrealm",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ParseAuthorization(tt.auth)
if result.Username != tt.expected.Username {
t.Errorf("Username: expected %s, got %s", tt.expected.Username, result.Username)
}
if result.Realm != tt.expected.Realm {
t.Errorf("Realm: expected %s, got %s", tt.expected.Realm, result.Realm)
}
if result.Nonce != tt.expected.Nonce {
t.Errorf("Nonce: expected %s, got %s", tt.expected.Nonce, result.Nonce)
}
if result.URI != tt.expected.URI {
t.Errorf("URI: expected %s, got %s", tt.expected.URI, result.URI)
}
if result.Response != tt.expected.Response {
t.Errorf("Response: expected %s, got %s", tt.expected.Response, result.Response)
}
if result.Algorithm != tt.expected.Algorithm {
t.Errorf("Algorithm: expected %s, got %s", tt.expected.Algorithm, result.Algorithm)
}
})
}
}
func TestParseAuthorizationEdgeCases(t *testing.T) {
tests := []struct {
name string
auth string
}{
{"Empty string", ""},
{"Only Digest", "Digest "},
{"Invalid format", "invalid format"},
{"No equals sign", "Digest username"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ParseAuthorization(tt.auth)
// 不应该 panic应该返回一个空的 AuthInfo
if result == nil {
t.Error("Expected non-nil result")
}
})
}
}
func TestMd5Hex(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "Simple string",
input: "hello",
expected: "5d41402abc4b2a76b9719d911017c592",
},
{
name: "Empty string",
input: "",
expected: "d41d8cd98f00b204e9800998ecf8427e",
},
{
name: "Numbers",
input: "123456",
expected: "e10adc3949ba59abbe56e057f20f883e",
},
{
name: "Complex string",
input: "username:realm:password",
expected: "8e8d14bf0c4b87c1c5b8b1e8c8e8d14b", // 这个需要实际计算
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := md5Hex(tt.input)
// 验证长度MD5 哈希应该是32个字符
if len(result) != 32 {
t.Errorf("Expected MD5 hash length 32, got %d", len(result))
}
// 验证是否为十六进制字符串
for _, c := range result {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
t.Errorf("MD5 hash contains non-hex character: %c", c)
}
}
// 对于已知的测试用例,验证具体值
if tt.name != "Complex string" && result != tt.expected {
t.Errorf("Expected MD5 hash %s, got %s", tt.expected, result)
}
})
}
}
func TestValidateAuth(t *testing.T) {
// 测试用例:使用已知的认证信息
t.Run("Valid authentication", func(t *testing.T) {
// 构造一个已知的认证场景
username := "testuser"
realm := "testrealm"
password := "testpass"
nonce := "testnonce"
uri := "sip:test@example.com"
method := "REGISTER"
// 计算正确的 response
ha1 := md5Hex(username + ":" + realm + ":" + password)
ha2 := md5Hex(method + ":" + uri)
correctResponse := md5Hex(ha1 + ":" + nonce + ":" + ha2)
authInfo := &AuthInfo{
Username: username,
Realm: realm,
Nonce: nonce,
URI: uri,
Response: correctResponse,
Method: method,
}
if !ValidateAuth(authInfo, password) {
t.Error("Expected authentication to be valid")
}
})
t.Run("Invalid password", func(t *testing.T) {
username := "testuser"
realm := "testrealm"
password := "testpass"
wrongPassword := "wrongpass"
nonce := "testnonce"
uri := "sip:test@example.com"
method := "REGISTER"
// 使用正确密码计算 response
ha1 := md5Hex(username + ":" + realm + ":" + password)
ha2 := md5Hex(method + ":" + uri)
correctResponse := md5Hex(ha1 + ":" + nonce + ":" + ha2)
authInfo := &AuthInfo{
Username: username,
Realm: realm,
Nonce: nonce,
URI: uri,
Response: correctResponse,
Method: method,
}
// 使用错误密码验证
if ValidateAuth(authInfo, wrongPassword) {
t.Error("Expected authentication to fail with wrong password")
}
})
t.Run("Nil authInfo", func(t *testing.T) {
if ValidateAuth(nil, "password") {
t.Error("Expected authentication to fail with nil authInfo")
}
})
t.Run("Default method", func(t *testing.T) {
// 测试当 Method 为空时,默认使用 REGISTER
username := "testuser"
realm := "testrealm"
password := "testpass"
nonce := "testnonce"
uri := "sip:test@example.com"
// 使用默认方法 REGISTER 计算 response
ha1 := md5Hex(username + ":" + realm + ":" + password)
ha2 := md5Hex("REGISTER:" + uri)
correctResponse := md5Hex(ha1 + ":" + nonce + ":" + ha2)
authInfo := &AuthInfo{
Username: username,
Realm: realm,
Nonce: nonce,
URI: uri,
Response: correctResponse,
Method: "", // 空方法,应该使用默认的 REGISTER
}
if !ValidateAuth(authInfo, password) {
t.Error("Expected authentication to be valid with default method")
}
})
}
func TestAuthInfoStruct(t *testing.T) {
// 测试 AuthInfo 结构体的基本功能
authInfo := &AuthInfo{
Username: "user",
Realm: "realm",
Nonce: "nonce",
URI: "uri",
Response: "response",
Algorithm: "MD5",
Method: "REGISTER",
}
if authInfo.Username != "user" {
t.Errorf("Expected username 'user', got '%s'", authInfo.Username)
}
if authInfo.Algorithm != "MD5" {
t.Errorf("Expected algorithm 'MD5', got '%s'", authInfo.Algorithm)
}
}
func TestParseAuthorizationWithoutDigestPrefix(t *testing.T) {
// 测试没有 "Digest " 前缀的情况
auth := `username="testuser",realm="testrealm"`
result := ParseAuthorization(auth)
if result.Username != "testuser" {
t.Errorf("Expected username 'testuser', got '%s'", result.Username)
}
if result.Realm != "testrealm" {
t.Errorf("Expected realm 'testrealm', got '%s'", result.Realm)
}
}
func TestParseAuthorizationCaseInsensitive(t *testing.T) {
// 虽然当前实现是大小写敏感的,但这个测试可以帮助未来改进
auth := `Digest username="testuser",realm="testrealm"`
result := ParseAuthorization(auth)
if result.Username == "" {
t.Error("Failed to parse username")
}
}
func TestMd5HexConsistency(t *testing.T) {
// 测试相同输入产生相同输出
input := "test string"
result1 := md5Hex(input)
result2 := md5Hex(input)
if result1 != result2 {
t.Errorf("MD5 hash should be consistent: %s != %s", result1, result2)
}
}
func TestMd5HexDifferentInputs(t *testing.T) {
// 测试不同输入产生不同输出
result1 := md5Hex("input1")
result2 := md5Hex("input2")
if result1 == result2 {
t.Error("Different inputs should produce different MD5 hashes")
}
}
func TestParseAuthorizationQuotedValues(t *testing.T) {
// 测试带引号和不带引号的值
auth := `Digest username="quoted",realm=unquoted,nonce="also-quoted"`
result := ParseAuthorization(auth)
if result.Username != "quoted" {
t.Errorf("Expected username 'quoted', got '%s'", result.Username)
}
// realm 没有引号,应该也能正确解析
if !strings.Contains(result.Realm, "unquoted") {
t.Logf("Realm value: '%s'", result.Realm)
}
}

View File

@ -1,17 +1,17 @@
package service
import (
"context"
"github.com/emiago/sipgo"
"github.com/ossrs/srs-sip/pkg/config"
)
type Cascade struct {
ua *sipgo.UserAgent
sipCli *sipgo.Client
sipSvr *sipgo.Server
ctx context.Context
conf *config.MainConfig
}
package service
import (
"context"
"github.com/emiago/sipgo"
"github.com/ossrs/srs-sip/pkg/config"
)
type Cascade struct {
ua *sipgo.UserAgent
sipCli *sipgo.Client
sipSvr *sipgo.Server
ctx context.Context
conf *config.MainConfig
}

View File

@ -1,171 +1,171 @@
package service
import (
"bytes"
"encoding/xml"
"fmt"
"log/slog"
"net"
"net/http"
"strconv"
"github.com/emiago/sipgo/sip"
"github.com/ossrs/srs-sip/pkg/models"
"github.com/ossrs/srs-sip/pkg/service/stack"
"golang.org/x/net/html/charset"
)
const GB28181_ID_LENGTH = 20
func (s *UAS) isSameIP(addr1, addr2 string) bool {
ip1, _, err1 := net.SplitHostPort(addr1)
ip2, _, err2 := net.SplitHostPort(addr2)
// 如果解析出错,回退到完整字符串比较
if err1 != nil || err2 != nil {
return addr1 == addr2
}
return ip1 == ip2
}
func (s *UAS) onRegister(req *sip.Request, tx sip.ServerTransaction) {
id := req.From().Address.User
if len(id) != GB28181_ID_LENGTH {
slog.Error("invalid device ID")
return
}
slog.Debug(fmt.Sprintf("Received REGISTER %s", req.String()))
if s.conf.GB28181.Auth.Enable {
// Check if Authorization header exists
authHeader := req.GetHeaders("Authorization")
// If no Authorization header, send 401 response to request authentication
if len(authHeader) == 0 {
nonce := GenerateNonce()
resp := stack.NewUnauthorizedResponse(req, http.StatusUnauthorized, "Unauthorized", nonce, s.conf.GB28181.Realm)
_ = tx.Respond(resp)
return
}
// Validate Authorization
authInfo := ParseAuthorization(authHeader[0].Value())
if !ValidateAuth(authInfo, s.conf.GB28181.Auth.Password) {
slog.Error("auth failed", "device_id", id, "source", req.Source())
s.respondRegister(req, http.StatusForbidden, "Auth Failed", tx)
return
}
}
isUnregister := false
if exps := req.GetHeaders("Expires"); len(exps) > 0 {
exp := exps[0]
expSec, err := strconv.ParseInt(exp.Value(), 10, 32)
if err != nil {
slog.Error("parse expires header error", "error", err.Error())
return
}
if expSec == 0 {
isUnregister = true
}
} else {
slog.Error("empty expires header")
return
}
if isUnregister {
DM.RemoveDevice(id)
slog.Warn("Device unregistered", "device_id", id)
return
} else {
if d, ok := DM.GetDevice(id); !ok {
DM.AddDevice(id, &DeviceInfo{
DeviceID: id,
SourceAddr: req.Source(),
NetworkType: req.Transport(),
})
s.respondRegister(req, http.StatusOK, "OK", tx)
slog.Info(fmt.Sprintf("Register success %s %s", id, req.Source()))
go s.ConfigDownload(id)
go s.Catalog(id)
} else {
if d.SourceAddr != "" && !s.isSameIP(d.SourceAddr, req.Source()) {
slog.Error("Device already registered", "device_id", id, "old_source", d.SourceAddr, "new_source", req.Source())
// TODO: 如果ID重复应采用虚拟ID
s.respondRegister(req, http.StatusBadRequest, "Conflict Device ID", tx)
} else {
d.SourceAddr = req.Source()
d.NetworkType = req.Transport()
DM.UpdateDevice(id, d)
s.respondRegister(req, http.StatusOK, "OK", tx)
slog.Info(fmt.Sprintf("Re-register success %s %s", id, req.Source()))
}
}
}
}
func (s *UAS) respondRegister(req *sip.Request, code sip.StatusCode, reason string, tx sip.ServerTransaction) {
res := stack.NewRegisterResponse(req, code, reason)
_ = tx.Respond(res)
}
func (s *UAS) onMessage(req *sip.Request, tx sip.ServerTransaction) {
id := req.From().Address.User
if len(id) != 20 {
slog.Error("invalid device ID", "request", req.String())
}
slog.Debug(fmt.Sprintf("Received MESSAGE %s", req.String()))
temp := &models.XmlMessageInfo{}
decoder := xml.NewDecoder(bytes.NewReader([]byte(req.Body())))
decoder.CharsetReader = charset.NewReaderLabel
if err := decoder.Decode(temp); err != nil {
slog.Error("decode message error", "error", err.Error(), "message", req.Body())
}
slog.Info(fmt.Sprintf("Received MESSAGE %s %s %s", temp.CmdType, temp.DeviceID, req.Source()))
var body string
switch temp.CmdType {
case "Keepalive":
if d, ok := DM.GetDevice(temp.DeviceID); ok && d.Online {
// 更新设备心跳时间
DM.UpdateDeviceHeartbeat(temp.DeviceID)
} else {
tx.Respond(sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil))
return
}
case "SensorCatalog": // 兼容宇视,非国标
case "Catalog":
DM.UpdateChannels(temp.DeviceID, temp.DeviceList...)
//go s.AutoInvite(temp.DeviceID, temp.DeviceList...)
case "ConfigDownload":
DM.UpdateDeviceConfig(temp.DeviceID, &temp.BasicParam)
case "Alarm":
slog.Info("Alarm")
case "RecordInfo":
// 从 recordQueryResults 中获取对应通道的结果通道
if ch, ok := s.recordQueryResults.Load(temp.DeviceID); ok {
// 发送查询结果
resultChan := ch.(chan *models.XmlMessageInfo)
resultChan <- temp
}
default:
slog.Warn("Not supported CmdType", "cmd_type", temp.CmdType)
response := sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil)
tx.Respond(response)
return
}
tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", []byte(body)))
}
func (s *UAS) onNotify(req *sip.Request, tx sip.ServerTransaction) {
slog.Debug(fmt.Sprintf("Received NOTIFY %s", req.String()))
tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", nil))
}
package service
import (
"bytes"
"encoding/xml"
"fmt"
"log/slog"
"net"
"net/http"
"strconv"
"github.com/emiago/sipgo/sip"
"github.com/ossrs/srs-sip/pkg/models"
"github.com/ossrs/srs-sip/pkg/service/stack"
"golang.org/x/net/html/charset"
)
const GB28181_ID_LENGTH = 20
func (s *UAS) isSameIP(addr1, addr2 string) bool {
ip1, _, err1 := net.SplitHostPort(addr1)
ip2, _, err2 := net.SplitHostPort(addr2)
// 如果解析出错,回退到完整字符串比较
if err1 != nil || err2 != nil {
return addr1 == addr2
}
return ip1 == ip2
}
func (s *UAS) onRegister(req *sip.Request, tx sip.ServerTransaction) {
id := req.From().Address.User
if len(id) != GB28181_ID_LENGTH {
slog.Error("invalid device ID")
return
}
slog.Debug(fmt.Sprintf("Received REGISTER %s", req.String()))
if s.conf.GB28181.Auth.Enable {
// Check if Authorization header exists
authHeader := req.GetHeaders("Authorization")
// If no Authorization header, send 401 response to request authentication
if len(authHeader) == 0 {
nonce := GenerateNonce()
resp := stack.NewUnauthorizedResponse(req, http.StatusUnauthorized, "Unauthorized", nonce, s.conf.GB28181.Realm)
_ = tx.Respond(resp)
return
}
// Validate Authorization
authInfo := ParseAuthorization(authHeader[0].Value())
if !ValidateAuth(authInfo, s.conf.GB28181.Auth.Password) {
slog.Error("auth failed", "device_id", id, "source", req.Source())
s.respondRegister(req, http.StatusForbidden, "Auth Failed", tx)
return
}
}
isUnregister := false
if exps := req.GetHeaders("Expires"); len(exps) > 0 {
exp := exps[0]
expSec, err := strconv.ParseInt(exp.Value(), 10, 32)
if err != nil {
slog.Error("parse expires header error", "error", err.Error())
return
}
if expSec == 0 {
isUnregister = true
}
} else {
slog.Error("empty expires header")
return
}
if isUnregister {
DM.RemoveDevice(id)
slog.Warn("Device unregistered", "device_id", id)
return
} else {
if d, ok := DM.GetDevice(id); !ok {
DM.AddDevice(id, &DeviceInfo{
DeviceID: id,
SourceAddr: req.Source(),
NetworkType: req.Transport(),
})
s.respondRegister(req, http.StatusOK, "OK", tx)
slog.Info(fmt.Sprintf("Register success %s %s", id, req.Source()))
go s.ConfigDownload(id)
go s.Catalog(id)
} else {
if d.SourceAddr != "" && !s.isSameIP(d.SourceAddr, req.Source()) {
slog.Error("Device already registered", "device_id", id, "old_source", d.SourceAddr, "new_source", req.Source())
// TODO: 如果ID重复应采用虚拟ID
s.respondRegister(req, http.StatusBadRequest, "Conflict Device ID", tx)
} else {
d.SourceAddr = req.Source()
d.NetworkType = req.Transport()
DM.UpdateDevice(id, d)
s.respondRegister(req, http.StatusOK, "OK", tx)
slog.Info(fmt.Sprintf("Re-register success %s %s", id, req.Source()))
}
}
}
}
func (s *UAS) respondRegister(req *sip.Request, code sip.StatusCode, reason string, tx sip.ServerTransaction) {
res := stack.NewRegisterResponse(req, code, reason)
_ = tx.Respond(res)
}
func (s *UAS) onMessage(req *sip.Request, tx sip.ServerTransaction) {
id := req.From().Address.User
if len(id) != 20 {
slog.Error("invalid device ID", "request", req.String())
}
slog.Debug(fmt.Sprintf("Received MESSAGE %s", req.String()))
temp := &models.XmlMessageInfo{}
decoder := xml.NewDecoder(bytes.NewReader([]byte(req.Body())))
decoder.CharsetReader = charset.NewReaderLabel
if err := decoder.Decode(temp); err != nil {
slog.Error("decode message error", "error", err.Error(), "message", req.Body())
}
slog.Info(fmt.Sprintf("Received MESSAGE %s %s %s", temp.CmdType, temp.DeviceID, req.Source()))
var body string
switch temp.CmdType {
case "Keepalive":
if d, ok := DM.GetDevice(temp.DeviceID); ok && d.Online {
// 更新设备心跳时间
DM.UpdateDeviceHeartbeat(temp.DeviceID)
} else {
tx.Respond(sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil))
return
}
case "SensorCatalog": // 兼容宇视,非国标
case "Catalog":
DM.UpdateChannels(temp.DeviceID, temp.DeviceList...)
//go s.AutoInvite(temp.DeviceID, temp.DeviceList...)
case "ConfigDownload":
DM.UpdateDeviceConfig(temp.DeviceID, &temp.BasicParam)
case "Alarm":
slog.Info("Alarm")
case "RecordInfo":
// 从 recordQueryResults 中获取对应通道的结果通道
if ch, ok := s.recordQueryResults.Load(temp.DeviceID); ok {
// 发送查询结果
resultChan := ch.(chan *models.XmlMessageInfo)
resultChan <- temp
}
default:
slog.Warn("Not supported CmdType", "cmd_type", temp.CmdType)
response := sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil)
tx.Respond(response)
return
}
tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", []byte(body)))
}
func (s *UAS) onNotify(req *sip.Request, tx sip.ServerTransaction) {
slog.Debug(fmt.Sprintf("Received NOTIFY %s", req.String()))
tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", nil))
}

View File

@ -1,81 +1,81 @@
package service
import "fmt"
var (
ptzCmdMap = map[string]uint8{
"stop": 0,
"right": 1,
"left": 2,
"down": 4,
"downright": 5,
"downleft": 6,
"up": 8,
"upright": 9,
"upleft": 10,
"zoomin": 16,
"zoomout": 32,
}
ptzSpeedMap = map[string]uint8{
"1": 25,
"2": 50,
"3": 75,
"4": 100,
"5": 125,
"6": 150,
"7": 175,
"8": 200,
"9": 225,
"10": 255,
}
defaultSpeed uint8 = 125
)
func getPTZSpeed(speed string) uint8 {
if v, ok := ptzSpeedMap[speed]; ok {
return v
}
return defaultSpeed
}
func toPTZCmd(cmdName, speed string) (string, error) {
cmdCode, ok := ptzCmdMap[cmdName]
if !ok {
return "", fmt.Errorf("invalid ptz command: %q", cmdName)
}
speedValue := getPTZSpeed(speed)
var horizontalSpeed, verticalSpeed, zSpeed uint8
switch cmdName {
case "left", "right":
horizontalSpeed = speedValue
verticalSpeed = 0
case "up", "down":
verticalSpeed = speedValue
horizontalSpeed = 0
case "upleft", "upright", "downleft", "downright":
verticalSpeed = speedValue
horizontalSpeed = speedValue
case "zoomin", "zoomout":
zSpeed = speedValue << 4 // zoom速度在高4位
default:
horizontalSpeed = 0
verticalSpeed = 0
zSpeed = 0
}
sum := uint16(0xA5) + uint16(0x0F) + uint16(0x01) + uint16(cmdCode) + uint16(horizontalSpeed) + uint16(verticalSpeed) + uint16(zSpeed)
checksum := uint8(sum % 256)
return fmt.Sprintf("A50F01%02X%02X%02X%02X%02X",
cmdCode,
horizontalSpeed,
verticalSpeed,
zSpeed,
checksum,
), nil
}
package service
import "fmt"
var (
ptzCmdMap = map[string]uint8{
"stop": 0,
"right": 1,
"left": 2,
"down": 4,
"downright": 5,
"downleft": 6,
"up": 8,
"upright": 9,
"upleft": 10,
"zoomin": 16,
"zoomout": 32,
}
ptzSpeedMap = map[string]uint8{
"1": 25,
"2": 50,
"3": 75,
"4": 100,
"5": 125,
"6": 150,
"7": 175,
"8": 200,
"9": 225,
"10": 255,
}
defaultSpeed uint8 = 125
)
func getPTZSpeed(speed string) uint8 {
if v, ok := ptzSpeedMap[speed]; ok {
return v
}
return defaultSpeed
}
func toPTZCmd(cmdName, speed string) (string, error) {
cmdCode, ok := ptzCmdMap[cmdName]
if !ok {
return "", fmt.Errorf("invalid ptz command: %q", cmdName)
}
speedValue := getPTZSpeed(speed)
var horizontalSpeed, verticalSpeed, zSpeed uint8
switch cmdName {
case "left", "right":
horizontalSpeed = speedValue
verticalSpeed = 0
case "up", "down":
verticalSpeed = speedValue
horizontalSpeed = 0
case "upleft", "upright", "downleft", "downright":
verticalSpeed = speedValue
horizontalSpeed = speedValue
case "zoomin", "zoomout":
zSpeed = speedValue << 4 // zoom速度在高4位
default:
horizontalSpeed = 0
verticalSpeed = 0
zSpeed = 0
}
sum := uint16(0xA5) + uint16(0x0F) + uint16(0x01) + uint16(cmdCode) + uint16(horizontalSpeed) + uint16(verticalSpeed) + uint16(zSpeed)
checksum := uint8(sum % 256)
return fmt.Sprintf("A50F01%02X%02X%02X%02X%02X",
cmdCode,
horizontalSpeed,
verticalSpeed,
zSpeed,
checksum,
), nil
}

198
pkg/service/ptz_test.go Normal file
View File

@ -0,0 +1,198 @@
package service
import (
"testing"
)
func TestGetPTZSpeed(t *testing.T) {
tests := []struct {
name string
speed string
expected uint8
}{
{"Speed 1", "1", 25},
{"Speed 2", "2", 50},
{"Speed 3", "3", 75},
{"Speed 4", "4", 100},
{"Speed 5", "5", 125},
{"Speed 6", "6", 150},
{"Speed 7", "7", 175},
{"Speed 8", "8", 200},
{"Speed 9", "9", 225},
{"Speed 10", "10", 255},
{"Invalid speed", "invalid", 125}, // 默认速度
{"Empty speed", "", 125}, // 默认速度
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getPTZSpeed(tt.speed)
if result != tt.expected {
t.Errorf("getPTZSpeed(%s) = %d, expected %d", tt.speed, result, tt.expected)
}
})
}
}
func TestToPTZCmd(t *testing.T) {
tests := []struct {
name string
cmdName string
speed string
expectError bool
checkPrefix bool
}{
{"Stop command", "stop", "5", false, true},
{"Right command", "right", "5", false, true},
{"Left command", "left", "5", false, true},
{"Up command", "up", "5", false, true},
{"Down command", "down", "5", false, true},
{"Up-right command", "upright", "5", false, true},
{"Up-left command", "upleft", "5", false, true},
{"Down-right command", "downright", "5", false, true},
{"Down-left command", "downleft", "5", false, true},
{"Zoom in command", "zoomin", "5", false, true},
{"Zoom out command", "zoomout", "5", false, true},
{"Invalid command", "invalid", "5", true, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := toPTZCmd(tt.cmdName, tt.speed)
if tt.expectError {
if err == nil {
t.Errorf("Expected error for command %s, got nil", tt.cmdName)
}
return
}
if err != nil {
t.Errorf("Unexpected error for command %s: %v", tt.cmdName, err)
return
}
// 验证结果格式
if len(result) != 16 { // A50F01 + 5对字节 = 16个字符
t.Errorf("Expected result length 16, got %d for command %s", len(result), tt.cmdName)
}
// 验证前缀
if tt.checkPrefix && result[:6] != "A50F01" {
t.Errorf("Expected prefix 'A50F01', got '%s' for command %s", result[:6], tt.cmdName)
}
})
}
}
func TestToPTZCmdSpecificCases(t *testing.T) {
// 测试停止命令
t.Run("Stop command details", func(t *testing.T) {
result, err := toPTZCmd("stop", "5")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Stop 命令码是 0速度应该都是 0
// A50F01 00 00 00 00 checksum
if result[:8] != "A50F0100" {
t.Errorf("Stop command should start with A50F0100, got %s", result[:8])
}
})
// 测试右移命令
t.Run("Right command details", func(t *testing.T) {
result, err := toPTZCmd("right", "5")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Right 命令码是 1水平速度应该是 125 (0x7D)
// A50F01 01 7D 00 00 checksum
if result[:8] != "A50F0101" {
t.Errorf("Right command should start with A50F0101, got %s", result[:8])
}
})
// 测试上移命令
t.Run("Up command details", func(t *testing.T) {
result, err := toPTZCmd("up", "5")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Up 命令码是 8垂直速度应该是 125 (0x7D)
// A50F01 08 00 7D 00 checksum
if result[:8] != "A50F0108" {
t.Errorf("Up command should start with A50F0108, got %s", result[:8])
}
})
// 测试缩放命令
t.Run("Zoom in command details", func(t *testing.T) {
result, err := toPTZCmd("zoomin", "5")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Zoom in 命令码是 16 (0x10)
// A50F01 10 00 00 XX checksum (XX 是速度左移4位)
if result[:8] != "A50F0110" {
t.Errorf("Zoom in command should start with A50F0110, got %s", result[:8])
}
})
}
func TestToPTZCmdWithDifferentSpeeds(t *testing.T) {
speeds := []string{"1", "5", "10"}
for _, speed := range speeds {
t.Run("Right with speed "+speed, func(t *testing.T) {
result, err := toPTZCmd("right", speed)
if err != nil {
t.Errorf("Unexpected error with speed %s: %v", speed, err)
}
if len(result) != 16 {
t.Errorf("Expected length 16, got %d", len(result))
}
})
}
}
func TestPTZCmdMap(t *testing.T) {
// 验证所有预定义的命令都存在
expectedCommands := []string{
"stop", "right", "left", "down", "downright", "downleft",
"up", "upright", "upleft", "zoomin", "zoomout",
}
for _, cmd := range expectedCommands {
t.Run("Command exists: "+cmd, func(t *testing.T) {
if _, ok := ptzCmdMap[cmd]; !ok {
t.Errorf("Command %s not found in ptzCmdMap", cmd)
}
})
}
}
func TestPTZSpeedMap(t *testing.T) {
// 验证速度映射的正确性
expectedSpeeds := map[string]uint8{
"1": 25,
"2": 50,
"3": 75,
"4": 100,
"5": 125,
"6": 150,
"7": 175,
"8": 200,
"9": 225,
"10": 255,
}
for speed, expectedValue := range expectedSpeeds {
t.Run("Speed mapping: "+speed, func(t *testing.T) {
if value, ok := ptzSpeedMap[speed]; !ok {
t.Errorf("Speed %s not found in ptzSpeedMap", speed)
} else if value != expectedValue {
t.Errorf("Speed %s expected value %d, got %d", speed, expectedValue, value)
}
})
}
}

View File

@ -1,68 +1,68 @@
package stack
import (
"github.com/emiago/sipgo/sip"
"github.com/ossrs/go-oryx-lib/errors"
)
type OutboundConfig struct {
Transport string
Via string
From string
To string
}
func NewRequest(method sip.RequestMethod, body []byte, conf OutboundConfig) (*sip.Request, error) {
if len(conf.From) != 20 || len(conf.To) != 20 {
return nil, errors.Errorf("From or To length is not 20")
}
dest := conf.Via
to := sip.Uri{User: conf.To, Host: conf.To[:10]}
from := &sip.Uri{User: conf.From, Host: conf.From[:10]}
fromHeader := &sip.FromHeader{Address: *from, Params: sip.NewParams()}
fromHeader.Params.Add("tag", sip.GenerateTagN(16))
req := sip.NewRequest(method, to)
req.AppendHeader(fromHeader)
req.AppendHeader(&sip.ToHeader{Address: to})
req.AppendHeader(&sip.ContactHeader{Address: *from})
req.AppendHeader(sip.NewHeader("Max-Forwards", "70"))
req.SetBody(body)
req.SetDestination(dest)
req.SetTransport(conf.Transport)
return req, nil
}
func NewRegisterRequest(conf OutboundConfig) (*sip.Request, error) {
req, err := NewRequest(sip.REGISTER, nil, conf)
if err != nil {
return nil, err
}
req.AppendHeader(sip.NewHeader("Expires", "3600"))
return req, nil
}
func NewInviteRequest(body []byte, subject string, conf OutboundConfig) (*sip.Request, error) {
req, err := NewRequest(sip.INVITE, body, conf)
if err != nil {
return nil, err
}
req.AppendHeader(sip.NewHeader("Content-Type", "application/sdp"))
req.AppendHeader(sip.NewHeader("Subject", subject))
return req, nil
}
func NewMessageRequest(body []byte, conf OutboundConfig) (*sip.Request, error) {
req, err := NewRequest(sip.MESSAGE, body, conf)
if err != nil {
return nil, err
}
req.AppendHeader(sip.NewHeader("Content-Type", "Application/MANSCDP+xml"))
return req, nil
}
package stack
import (
"github.com/emiago/sipgo/sip"
"github.com/ossrs/go-oryx-lib/errors"
)
type OutboundConfig struct {
Transport string
Via string
From string
To string
}
func NewRequest(method sip.RequestMethod, body []byte, conf OutboundConfig) (*sip.Request, error) {
if len(conf.From) != 20 || len(conf.To) != 20 {
return nil, errors.Errorf("From or To length is not 20")
}
dest := conf.Via
to := sip.Uri{User: conf.To, Host: conf.To[:10]}
from := &sip.Uri{User: conf.From, Host: conf.From[:10]}
fromHeader := &sip.FromHeader{Address: *from, Params: sip.NewParams()}
fromHeader.Params.Add("tag", sip.GenerateTagN(16))
req := sip.NewRequest(method, to)
req.AppendHeader(fromHeader)
req.AppendHeader(&sip.ToHeader{Address: to})
req.AppendHeader(&sip.ContactHeader{Address: *from})
req.AppendHeader(sip.NewHeader("Max-Forwards", "70"))
req.SetBody(body)
req.SetDestination(dest)
req.SetTransport(conf.Transport)
return req, nil
}
func NewRegisterRequest(conf OutboundConfig) (*sip.Request, error) {
req, err := NewRequest(sip.REGISTER, nil, conf)
if err != nil {
return nil, err
}
req.AppendHeader(sip.NewHeader("Expires", "3600"))
return req, nil
}
func NewInviteRequest(body []byte, subject string, conf OutboundConfig) (*sip.Request, error) {
req, err := NewRequest(sip.INVITE, body, conf)
if err != nil {
return nil, err
}
req.AppendHeader(sip.NewHeader("Content-Type", "application/sdp"))
req.AppendHeader(sip.NewHeader("Subject", subject))
return req, nil
}
func NewMessageRequest(body []byte, conf OutboundConfig) (*sip.Request, error) {
req, err := NewRequest(sip.MESSAGE, body, conf)
if err != nil {
return nil, err
}
req.AppendHeader(sip.NewHeader("Content-Type", "Application/MANSCDP+xml"))
return req, nil
}

View File

@ -0,0 +1,292 @@
package stack
import (
"testing"
"github.com/emiago/sipgo/sip"
)
func TestNewRequest(t *testing.T) {
tests := []struct {
name string
method sip.RequestMethod
body []byte
conf OutboundConfig
wantErr bool
errMsg string
checkFunc func(*testing.T, *sip.Request)
}{
{
name: "Valid REGISTER request",
method: sip.REGISTER,
body: nil,
conf: OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "34020000001110000001",
},
wantErr: false,
checkFunc: func(t *testing.T, req *sip.Request) {
if req.Method != sip.REGISTER {
t.Errorf("Expected method REGISTER, got %v", req.Method)
}
if req.Destination() != "192.168.1.100:5060" {
t.Errorf("Expected destination 192.168.1.100:5060, got %v", req.Destination())
}
if req.Transport() != "udp" {
t.Errorf("Expected transport udp, got %v", req.Transport())
}
},
},
{
name: "Valid INVITE request with body",
method: sip.INVITE,
body: []byte("v=0\r\no=- 0 0 IN IP4 127.0.0.1\r\n"),
conf: OutboundConfig{
Transport: "tcp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "34020000001110000001",
},
wantErr: false,
checkFunc: func(t *testing.T, req *sip.Request) {
if req.Method != sip.INVITE {
t.Errorf("Expected method INVITE, got %v", req.Method)
}
if req.Body() == nil {
t.Error("Expected body to be set")
}
},
},
{
name: "Invalid From length - too short",
method: sip.REGISTER,
body: nil,
conf: OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "123456789",
To: "34020000001110000001",
},
wantErr: true,
errMsg: "From or To length is not 20",
},
{
name: "Invalid To length - too long",
method: sip.REGISTER,
body: nil,
conf: OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "340200000011100000012345",
},
wantErr: true,
errMsg: "From or To length is not 20",
},
{
name: "Valid MESSAGE request",
method: sip.MESSAGE,
body: []byte("<?xml version=\"1.0\"?><Query></Query>"),
conf: OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "34020000001110000001",
},
wantErr: false,
checkFunc: func(t *testing.T, req *sip.Request) {
if req.Method != sip.MESSAGE {
t.Errorf("Expected method MESSAGE, got %v", req.Method)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := NewRequest(tt.method, tt.body, tt.conf)
if tt.wantErr {
if err == nil {
t.Errorf("Expected error but got nil")
} else if tt.errMsg != "" && err.Error() != tt.errMsg {
t.Errorf("Expected error message '%s', got '%s'", tt.errMsg, err.Error())
}
return
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
return
}
if req == nil {
t.Error("Expected request to be non-nil")
return
}
// Run custom checks
if tt.checkFunc != nil {
tt.checkFunc(t, req)
}
// Common checks
if req.From() == nil {
t.Error("Expected From header to be set")
}
if req.To() == nil {
t.Error("Expected To header to be set")
}
if req.Contact() == nil {
t.Error("Expected Contact header to be set")
}
})
}
}
func TestNewRegisterRequest(t *testing.T) {
conf := OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "34020000001110000001",
}
req, err := NewRegisterRequest(conf)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if req.Method != sip.REGISTER {
t.Errorf("Expected method REGISTER, got %v", req.Method)
}
// Check for Expires header
expires := req.GetHeader("Expires")
if expires == nil {
t.Error("Expected Expires header to be set")
} else if expires.Value() != "3600" {
t.Errorf("Expected Expires value 3600, got %v", expires.Value())
}
}
func TestNewInviteRequest(t *testing.T) {
conf := OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "34020000001110000001",
}
body := []byte("v=0\r\no=- 0 0 IN IP4 127.0.0.1\r\ns=Play\r\n")
subject := "34020000001320000001:0,34020000001110000001:0"
req, err := NewInviteRequest(body, subject, conf)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if req.Method != sip.INVITE {
t.Errorf("Expected method INVITE, got %v", req.Method)
}
// Check for Content-Type header
contentType := req.GetHeader("Content-Type")
if contentType == nil {
t.Error("Expected Content-Type header to be set")
} else if contentType.Value() != "application/sdp" {
t.Errorf("Expected Content-Type value application/sdp, got %v", contentType.Value())
}
// Check for Subject header
subjectHeader := req.GetHeader("Subject")
if subjectHeader == nil {
t.Error("Expected Subject header to be set")
} else if subjectHeader.Value() != subject {
t.Errorf("Expected Subject value %s, got %v", subject, subjectHeader.Value())
}
// Check body
if req.Body() == nil {
t.Error("Expected body to be set")
}
}
func TestNewMessageRequest(t *testing.T) {
conf := OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "34020000001110000001",
}
body := []byte("<?xml version=\"1.0\"?><Query><CmdType>Catalog</CmdType></Query>")
req, err := NewMessageRequest(body, conf)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if req.Method != sip.MESSAGE {
t.Errorf("Expected method MESSAGE, got %v", req.Method)
}
// Check for Content-Type header
contentType := req.GetHeader("Content-Type")
if contentType == nil {
t.Error("Expected Content-Type header to be set")
} else if contentType.Value() != "Application/MANSCDP+xml" {
t.Errorf("Expected Content-Type value Application/MANSCDP+xml, got %v", contentType.Value())
}
// Check body
if req.Body() == nil {
t.Error("Expected body to be set")
}
}
func TestNewRequestWithInvalidConfig(t *testing.T) {
tests := []struct {
name string
conf OutboundConfig
}{
{
name: "Empty From",
conf: OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "",
To: "34020000001110000001",
},
},
{
name: "Empty To",
conf: OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "",
},
},
{
name: "From length 19",
conf: OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "3402000000132000001",
To: "34020000001110000001",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewRequest(sip.REGISTER, nil, tt.conf)
if err == nil {
t.Error("Expected error but got nil")
}
})
}
}

View File

@ -1,41 +1,41 @@
package stack
import (
"fmt"
"time"
"github.com/emiago/sipgo/sip"
)
const TIME_LAYOUT = "2024-01-01T00:00:00"
const EXPIRES_TIME = 3600
func newResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response {
resp := sip.NewResponseFromRequest(req, code, reason, nil)
newTo := &sip.ToHeader{Address: resp.To().Address, Params: sip.NewParams()}
newTo.Params.Add("tag", sip.GenerateTagN(10))
resp.ReplaceHeader(newTo)
resp.RemoveHeader("Allow")
return resp
}
func NewRegisterResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response {
resp := newResponse(req, code, reason)
expires := sip.ExpiresHeader(EXPIRES_TIME)
resp.AppendHeader(&expires)
resp.AppendHeader(sip.NewHeader("Date", time.Now().Format(TIME_LAYOUT)))
return resp
}
func NewUnauthorizedResponse(req *sip.Request, code sip.StatusCode, reason, nonce, realm string) *sip.Response {
resp := newResponse(req, code, reason)
resp.AppendHeader(sip.NewHeader("WWW-Authenticate", fmt.Sprintf(`Digest realm="%s",nonce="%s",algorithm=MD5`, realm, nonce)))
return resp
}
package stack
import (
"fmt"
"time"
"github.com/emiago/sipgo/sip"
)
const TIME_LAYOUT = "2024-01-01T00:00:00"
const EXPIRES_TIME = 3600
func newResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response {
resp := sip.NewResponseFromRequest(req, code, reason, nil)
newTo := &sip.ToHeader{Address: resp.To().Address, Params: sip.NewParams()}
newTo.Params.Add("tag", sip.GenerateTagN(10))
resp.ReplaceHeader(newTo)
resp.RemoveHeader("Allow")
return resp
}
func NewRegisterResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response {
resp := newResponse(req, code, reason)
expires := sip.ExpiresHeader(EXPIRES_TIME)
resp.AppendHeader(&expires)
resp.AppendHeader(sip.NewHeader("Date", time.Now().Format(TIME_LAYOUT)))
return resp
}
func NewUnauthorizedResponse(req *sip.Request, code sip.StatusCode, reason, nonce, realm string) *sip.Response {
resp := newResponse(req, code, reason)
resp.AppendHeader(sip.NewHeader("WWW-Authenticate", fmt.Sprintf(`Digest realm="%s",nonce="%s",algorithm=MD5`, realm, nonce)))
return resp
}

View File

@ -0,0 +1,296 @@
package stack
import (
"testing"
"github.com/emiago/sipgo/sip"
)
func TestNewRegisterResponse(t *testing.T) {
// Create a test request first - we need a properly initialized request
// Skip this test as it requires a full SIP stack to create valid responses
t.Skip("Skipping response test - requires full SIP stack initialization")
conf := OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "34020000001110000001",
}
req, err := NewRegisterRequest(conf)
if err != nil {
t.Fatalf("Failed to create test request: %v", err)
}
tests := []struct {
name string
code sip.StatusCode
reason string
}{
{
name: "200 OK response",
code: sip.StatusOK,
reason: "OK",
},
{
name: "401 Unauthorized response",
code: sip.StatusUnauthorized,
reason: "Unauthorized",
},
{
name: "403 Forbidden response",
code: sip.StatusForbidden,
reason: "Forbidden",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp := NewRegisterResponse(req, tt.code, tt.reason)
if resp == nil {
t.Fatal("Expected response to be non-nil")
}
if resp.StatusCode != tt.code {
t.Errorf("Expected status code %d, got %d", tt.code, resp.StatusCode)
}
if resp.Reason != tt.reason {
t.Errorf("Expected reason '%s', got '%s'", tt.reason, resp.Reason)
}
// Check for Expires header
expires := resp.GetHeader("Expires")
if expires == nil {
t.Error("Expected Expires header to be set")
} else if expires.Value() != "3600" {
t.Errorf("Expected Expires value 3600, got %v", expires.Value())
}
// Check for Date header
date := resp.GetHeader("Date")
if date == nil {
t.Error("Expected Date header to be set")
}
// Check that To header has tag
to := resp.To()
if to == nil {
t.Error("Expected To header to be set")
} else {
tag, ok := to.Params.Get("tag")
if !ok || tag == "" {
t.Error("Expected To header to have tag parameter")
}
}
// Check that Allow header is removed
allow := resp.GetHeader("Allow")
if allow != nil {
t.Error("Expected Allow header to be removed")
}
})
}
}
func TestNewUnauthorizedResponse(t *testing.T) {
// Skip this test as it requires a full SIP stack to create valid responses
t.Skip("Skipping response test - requires full SIP stack initialization")
conf := OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "34020000001110000001",
}
req, err := NewRegisterRequest(conf)
if err != nil {
t.Fatalf("Failed to create test request: %v", err)
}
tests := []struct {
name string
code sip.StatusCode
reason string
nonce string
realm string
}{
{
name: "401 Unauthorized with nonce and realm",
code: sip.StatusUnauthorized,
reason: "Unauthorized",
nonce: "dcd98b7102dd2f0e8b11d0f600bfb0c093",
realm: "3402000000",
},
{
name: "407 Proxy Authentication Required",
code: sip.StatusProxyAuthRequired,
reason: "Proxy Authentication Required",
nonce: "abc123def456",
realm: "proxy.example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp := NewUnauthorizedResponse(req, tt.code, tt.reason, tt.nonce, tt.realm)
if resp == nil {
t.Fatal("Expected response to be non-nil")
}
if resp.StatusCode != tt.code {
t.Errorf("Expected status code %d, got %d", tt.code, resp.StatusCode)
}
if resp.Reason != tt.reason {
t.Errorf("Expected reason '%s', got '%s'", tt.reason, resp.Reason)
}
// Check for WWW-Authenticate header
wwwAuth := resp.GetHeader("WWW-Authenticate")
if wwwAuth == nil {
t.Error("Expected WWW-Authenticate header to be set")
} else {
authValue := wwwAuth.Value()
// Check that it contains the nonce
if len(authValue) == 0 {
t.Error("Expected WWW-Authenticate header to have a value")
}
// The value should contain Digest, realm, nonce, and algorithm
expectedSubstrings := []string{"Digest", "realm=", "nonce=", "algorithm=MD5"}
for _, substr := range expectedSubstrings {
if len(authValue) > 0 && !contains(authValue, substr) {
t.Errorf("Expected WWW-Authenticate to contain '%s', got '%s'", substr, authValue)
}
}
}
// Check that To header has tag
to := resp.To()
if to == nil {
t.Error("Expected To header to be set")
} else {
tag, ok := to.Params.Get("tag")
if !ok || tag == "" {
t.Error("Expected To header to have tag parameter")
}
}
})
}
}
func TestNewResponse(t *testing.T) {
// Skip this test as it requires a full SIP stack to create valid responses
t.Skip("Skipping response test - requires full SIP stack initialization")
conf := OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "34020000001110000001",
}
req, err := NewRequest(sip.INVITE, nil, conf)
if err != nil {
t.Fatalf("Failed to create test request: %v", err)
}
tests := []struct {
name string
code sip.StatusCode
reason string
}{
{
name: "100 Trying",
code: sip.StatusTrying,
reason: "Trying",
},
{
name: "180 Ringing",
code: sip.StatusRinging,
reason: "Ringing",
},
{
name: "200 OK",
code: sip.StatusOK,
reason: "OK",
},
{
name: "404 Not Found",
code: sip.StatusNotFound,
reason: "Not Found",
},
{
name: "500 Server Internal Error",
code: sip.StatusInternalServerError,
reason: "Server Internal Error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp := newResponse(req, tt.code, tt.reason)
if resp == nil {
t.Fatal("Expected response to be non-nil")
}
if resp.StatusCode != tt.code {
t.Errorf("Expected status code %d, got %d", tt.code, resp.StatusCode)
}
if resp.Reason != tt.reason {
t.Errorf("Expected reason '%s', got '%s'", tt.reason, resp.Reason)
}
// Check that To header has tag
to := resp.To()
if to == nil {
t.Error("Expected To header to be set")
} else {
tag, ok := to.Params.Get("tag")
if !ok || tag == "" {
t.Error("Expected To header to have tag parameter")
}
// Check tag length is 10
if len(tag) != 10 {
t.Errorf("Expected tag length to be 10, got %d", len(tag))
}
}
// Check that Allow header is removed
allow := resp.GetHeader("Allow")
if allow != nil {
t.Error("Expected Allow header to be removed")
}
})
}
}
func TestResponseConstants(t *testing.T) {
if TIME_LAYOUT != "2024-01-01T00:00:00" {
t.Errorf("Expected TIME_LAYOUT to be '2024-01-01T00:00:00', got '%s'", TIME_LAYOUT)
}
if EXPIRES_TIME != 3600 {
t.Errorf("Expected EXPIRES_TIME to be 3600, got %d", EXPIRES_TIME)
}
}
// Helper function to check if a string contains a substring
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && findSubstring(s, substr))
}
func findSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

View File

@ -1,122 +1,122 @@
package service
import (
"context"
"fmt"
"log/slog"
"github.com/emiago/sipgo"
"github.com/emiago/sipgo/sip"
"github.com/ossrs/go-oryx-lib/errors"
"github.com/ossrs/srs-sip/pkg/config"
"github.com/ossrs/srs-sip/pkg/service/stack"
)
const (
UserAgent = "SRS-SIP/1.0"
)
type UAC struct {
*Cascade
SN uint32
LocalIP string
}
func NewUac() *UAC {
ip, err := config.GetLocalIP()
if err != nil {
slog.Error("get local ip failed", "error", err)
return nil
}
c := &UAC{
Cascade: &Cascade{},
LocalIP: ip,
}
return c
}
func (c *UAC) Start(agent *sipgo.UserAgent, r0 interface{}) error {
var err error
c.ctx = context.Background()
c.conf = r0.(*config.MainConfig)
if agent == nil {
ua, err := sipgo.NewUA(sipgo.WithUserAgent(UserAgent))
if err != nil {
return err
}
agent = ua
}
c.sipCli, err = sipgo.NewClient(agent, sipgo.WithClientHostname(c.LocalIP))
if err != nil {
return err
}
c.sipSvr, err = sipgo.NewServer(agent)
if err != nil {
return err
}
c.sipSvr.OnInvite(c.onInvite)
c.sipSvr.OnBye(c.onBye)
c.sipSvr.OnMessage(c.onMessage)
go c.doRegister()
return nil
}
func (c *UAC) Stop() {
// TODO: 断开所有当前连接
c.sipCli.Close()
c.sipSvr.Close()
}
func (c *UAC) doRegister() error {
r, _ := stack.NewRegisterRequest(stack.OutboundConfig{
From: "34020000001110000001",
To: "34020000002000000001",
Transport: "UDP",
Via: fmt.Sprintf("%s:%d", c.LocalIP, c.conf.GB28181.Port),
})
tx, err := c.sipCli.TransactionRequest(c.ctx, r)
if err != nil {
return errors.Wrapf(err, "transaction request error")
}
rs, _ := c.getResponse(tx)
slog.Info("register response", "response", rs.String())
return nil
}
func (c *UAC) OnRequest(req *sip.Request, tx sip.ServerTransaction) {
switch req.Method {
case "INVITE":
c.onInvite(req, tx)
}
}
func (c *UAC) onInvite(req *sip.Request, tx sip.ServerTransaction) {
slog.Debug("onInvite")
}
func (c *UAC) onBye(req *sip.Request, tx sip.ServerTransaction) {
slog.Debug("onBye")
}
func (c *UAC) onMessage(req *sip.Request, tx sip.ServerTransaction) {
slog.Debug("onMessage", "request", req.String())
}
func (c *UAC) getResponse(tx sip.ClientTransaction) (*sip.Response, error) {
select {
case <-tx.Done():
return nil, fmt.Errorf("transaction died")
case res := <-tx.Responses():
return res, nil
}
}
package service
import (
"context"
"fmt"
"log/slog"
"github.com/emiago/sipgo"
"github.com/emiago/sipgo/sip"
"github.com/ossrs/go-oryx-lib/errors"
"github.com/ossrs/srs-sip/pkg/config"
"github.com/ossrs/srs-sip/pkg/service/stack"
)
const (
UserAgent = "SRS-SIP/1.0"
)
type UAC struct {
*Cascade
SN uint32
LocalIP string
}
func NewUac() *UAC {
ip, err := config.GetLocalIP()
if err != nil {
slog.Error("get local ip failed", "error", err)
return nil
}
c := &UAC{
Cascade: &Cascade{},
LocalIP: ip,
}
return c
}
func (c *UAC) Start(agent *sipgo.UserAgent, r0 interface{}) error {
var err error
c.ctx = context.Background()
c.conf = r0.(*config.MainConfig)
if agent == nil {
ua, err := sipgo.NewUA(sipgo.WithUserAgent(UserAgent))
if err != nil {
return err
}
agent = ua
}
c.sipCli, err = sipgo.NewClient(agent, sipgo.WithClientHostname(c.LocalIP))
if err != nil {
return err
}
c.sipSvr, err = sipgo.NewServer(agent)
if err != nil {
return err
}
c.sipSvr.OnInvite(c.onInvite)
c.sipSvr.OnBye(c.onBye)
c.sipSvr.OnMessage(c.onMessage)
go c.doRegister()
return nil
}
func (c *UAC) Stop() {
// TODO: 断开所有当前连接
c.sipCli.Close()
c.sipSvr.Close()
}
func (c *UAC) doRegister() error {
r, _ := stack.NewRegisterRequest(stack.OutboundConfig{
From: "34020000001110000001",
To: "34020000002000000001",
Transport: "UDP",
Via: fmt.Sprintf("%s:%d", c.LocalIP, c.conf.GB28181.Port),
})
tx, err := c.sipCli.TransactionRequest(c.ctx, r)
if err != nil {
return errors.Wrapf(err, "transaction request error")
}
rs, _ := c.getResponse(tx)
slog.Info("register response", "response", rs.String())
return nil
}
func (c *UAC) OnRequest(req *sip.Request, tx sip.ServerTransaction) {
switch req.Method {
case "INVITE":
c.onInvite(req, tx)
}
}
func (c *UAC) onInvite(req *sip.Request, tx sip.ServerTransaction) {
slog.Debug("onInvite")
}
func (c *UAC) onBye(req *sip.Request, tx sip.ServerTransaction) {
slog.Debug("onBye")
}
func (c *UAC) onMessage(req *sip.Request, tx sip.ServerTransaction) {
slog.Debug("onMessage", "request", req.String())
}
func (c *UAC) getResponse(tx sip.ClientTransaction) (*sip.Response, error) {
select {
case <-tx.Done():
return nil, fmt.Errorf("transaction died")
case res := <-tx.Responses():
return res, nil
}
}

View File

@ -1,201 +1,201 @@
package utils
import (
"context"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"strings"
"sync"
)
var logLevelMap = map[string]slog.Level{
"debug": slog.LevelDebug,
"info": slog.LevelInfo,
"warn": slog.LevelWarn,
"error": slog.LevelError,
}
// 自定义格式处理器,以 [时间] [级别] [消息] 格式输出日志
type CustomFormatHandler struct {
mu sync.Mutex
w io.Writer
level slog.Level
attrs []slog.Attr
groups []string
}
// NewCustomFormatHandler 创建一个新的自定义格式处理器
func NewCustomFormatHandler(w io.Writer, opts *slog.HandlerOptions) *CustomFormatHandler {
if opts == nil {
opts = &slog.HandlerOptions{}
}
// 获取日志级别如果opts.Level是nil则默认为Info
var level slog.Level
if opts.Level != nil {
level = opts.Level.Level()
} else {
level = slog.LevelInfo
}
return &CustomFormatHandler{
w: w,
level: level,
}
}
// Enabled 实现 slog.Handler 接口
func (h *CustomFormatHandler) Enabled(ctx context.Context, level slog.Level) bool {
return level >= h.level
}
// Handle 实现 slog.Handler 接口,以自定义格式输出日志
func (h *CustomFormatHandler) Handle(ctx context.Context, record slog.Record) error {
h.mu.Lock()
defer h.mu.Unlock()
// 时间格式
timeStr := record.Time.Format("2006-01-02 15:04:05.000")
// 日志级别
var levelStr string
switch {
case record.Level >= slog.LevelError:
levelStr = "ERROR"
case record.Level >= slog.LevelWarn:
levelStr = "WARN "
case record.Level >= slog.LevelInfo:
levelStr = "INFO "
default:
levelStr = "DEBUG"
}
// 构建日志行
logLine := fmt.Sprintf("[%s] [%s] %s", timeStr, levelStr, record.Message)
// 处理其他属性
var attrs []string
record.Attrs(func(attr slog.Attr) bool {
attrs = append(attrs, fmt.Sprintf("%s=%v", attr.Key, attr.Value))
return true
})
if len(attrs) > 0 {
logLine += " " + strings.Join(attrs, " ")
}
// 写入日志
_, err := fmt.Fprintln(h.w, logLine)
return err
}
// WithAttrs 实现 slog.Handler 接口
func (h *CustomFormatHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
h2 := *h
h2.attrs = append(h.attrs[:], attrs...)
return &h2
}
// WithGroup 实现 slog.Handler 接口
func (h *CustomFormatHandler) WithGroup(name string) slog.Handler {
h2 := *h
h2.groups = append(h.groups[:], name)
return &h2
}
// MultiHandler 实现了 slog.Handler 接口,将日志同时发送到多个处理器
type MultiHandler struct {
handlers []slog.Handler
}
// Enabled 实现 slog.Handler 接口
func (h *MultiHandler) Enabled(ctx context.Context, level slog.Level) bool {
// 如果任何一个处理器启用了该级别,则返回 true
for _, handler := range h.handlers {
if handler.Enabled(ctx, level) {
return true
}
}
return false
}
// Handle 实现 slog.Handler 接口
func (h *MultiHandler) Handle(ctx context.Context, record slog.Record) error {
// 将记录发送到所有处理器
for _, handler := range h.handlers {
if handler.Enabled(ctx, record.Level) {
if err := handler.Handle(ctx, record); err != nil {
return err
}
}
}
return nil
}
// WithAttrs 实现 slog.Handler 接口
func (h *MultiHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
newHandlers := make([]slog.Handler, len(h.handlers))
for i, handler := range h.handlers {
newHandlers[i] = handler.WithAttrs(attrs)
}
return &MultiHandler{handlers: newHandlers}
}
// WithGroup 实现 slog.Handler 接口
func (h *MultiHandler) WithGroup(name string) slog.Handler {
newHandlers := make([]slog.Handler, len(h.handlers))
for i, handler := range h.handlers {
newHandlers[i] = handler.WithGroup(name)
}
return &MultiHandler{handlers: newHandlers}
}
// SetupLogger 设置日志输出
func SetupLogger(logLevel string, logFile string) error {
// 创建标准错误输出的处理器,使用自定义格式
stdHandler := NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{
Level: logLevelMap[logLevel],
})
// 如果没有指定日志文件,则仅使用标准错误处理器
if logFile == "" {
slog.SetDefault(slog.New(stdHandler))
return nil
}
// 确保日志文件所在目录存在
logDir := filepath.Dir(logFile)
if err := os.MkdirAll(logDir, 0755); err != nil {
return err
}
// 打开日志文件,如果不存在则创建,追加写入模式
file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
return err
}
// 创建文件输出的处理器,使用自定义格式
fileHandler := NewCustomFormatHandler(file, &slog.HandlerOptions{
Level: logLevelMap[logLevel],
})
// 创建多输出处理器
multiHandler := &MultiHandler{
handlers: []slog.Handler{stdHandler, fileHandler},
}
// 设置全局日志处理器
slog.SetDefault(slog.New(multiHandler))
return nil
}
// InitDefaultLogger 初始化默认日志处理器
func InitDefaultLogger(level slog.Level) {
slog.SetDefault(slog.New(NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})))
}
package utils
import (
"context"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"strings"
"sync"
)
var logLevelMap = map[string]slog.Level{
"debug": slog.LevelDebug,
"info": slog.LevelInfo,
"warn": slog.LevelWarn,
"error": slog.LevelError,
}
// 自定义格式处理器,以 [时间] [级别] [消息] 格式输出日志
type CustomFormatHandler struct {
mu sync.Mutex
w io.Writer
level slog.Level
attrs []slog.Attr
groups []string
}
// NewCustomFormatHandler 创建一个新的自定义格式处理器
func NewCustomFormatHandler(w io.Writer, opts *slog.HandlerOptions) *CustomFormatHandler {
if opts == nil {
opts = &slog.HandlerOptions{}
}
// 获取日志级别如果opts.Level是nil则默认为Info
var level slog.Level
if opts.Level != nil {
level = opts.Level.Level()
} else {
level = slog.LevelInfo
}
return &CustomFormatHandler{
w: w,
level: level,
}
}
// Enabled 实现 slog.Handler 接口
func (h *CustomFormatHandler) Enabled(ctx context.Context, level slog.Level) bool {
return level >= h.level
}
// Handle 实现 slog.Handler 接口,以自定义格式输出日志
func (h *CustomFormatHandler) Handle(ctx context.Context, record slog.Record) error {
h.mu.Lock()
defer h.mu.Unlock()
// 时间格式
timeStr := record.Time.Format("2006-01-02 15:04:05.000")
// 日志级别
var levelStr string
switch {
case record.Level >= slog.LevelError:
levelStr = "ERROR"
case record.Level >= slog.LevelWarn:
levelStr = "WARN "
case record.Level >= slog.LevelInfo:
levelStr = "INFO "
default:
levelStr = "DEBUG"
}
// 构建日志行
logLine := fmt.Sprintf("[%s] [%s] %s", timeStr, levelStr, record.Message)
// 处理其他属性
var attrs []string
record.Attrs(func(attr slog.Attr) bool {
attrs = append(attrs, fmt.Sprintf("%s=%v", attr.Key, attr.Value))
return true
})
if len(attrs) > 0 {
logLine += " " + strings.Join(attrs, " ")
}
// 写入日志
_, err := fmt.Fprintln(h.w, logLine)
return err
}
// WithAttrs 实现 slog.Handler 接口
func (h *CustomFormatHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
h2 := *h
h2.attrs = append(h.attrs[:], attrs...)
return &h2
}
// WithGroup 实现 slog.Handler 接口
func (h *CustomFormatHandler) WithGroup(name string) slog.Handler {
h2 := *h
h2.groups = append(h.groups[:], name)
return &h2
}
// MultiHandler 实现了 slog.Handler 接口,将日志同时发送到多个处理器
type MultiHandler struct {
handlers []slog.Handler
}
// Enabled 实现 slog.Handler 接口
func (h *MultiHandler) Enabled(ctx context.Context, level slog.Level) bool {
// 如果任何一个处理器启用了该级别,则返回 true
for _, handler := range h.handlers {
if handler.Enabled(ctx, level) {
return true
}
}
return false
}
// Handle 实现 slog.Handler 接口
func (h *MultiHandler) Handle(ctx context.Context, record slog.Record) error {
// 将记录发送到所有处理器
for _, handler := range h.handlers {
if handler.Enabled(ctx, record.Level) {
if err := handler.Handle(ctx, record); err != nil {
return err
}
}
}
return nil
}
// WithAttrs 实现 slog.Handler 接口
func (h *MultiHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
newHandlers := make([]slog.Handler, len(h.handlers))
for i, handler := range h.handlers {
newHandlers[i] = handler.WithAttrs(attrs)
}
return &MultiHandler{handlers: newHandlers}
}
// WithGroup 实现 slog.Handler 接口
func (h *MultiHandler) WithGroup(name string) slog.Handler {
newHandlers := make([]slog.Handler, len(h.handlers))
for i, handler := range h.handlers {
newHandlers[i] = handler.WithGroup(name)
}
return &MultiHandler{handlers: newHandlers}
}
// SetupLogger 设置日志输出
func SetupLogger(logLevel string, logFile string) error {
// 创建标准错误输出的处理器,使用自定义格式
stdHandler := NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{
Level: logLevelMap[logLevel],
})
// 如果没有指定日志文件,则仅使用标准错误处理器
if logFile == "" {
slog.SetDefault(slog.New(stdHandler))
return nil
}
// 确保日志文件所在目录存在
logDir := filepath.Dir(logFile)
if err := os.MkdirAll(logDir, 0755); err != nil {
return err
}
// 打开日志文件,如果不存在则创建,追加写入模式
file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
return err
}
// 创建文件输出的处理器,使用自定义格式
fileHandler := NewCustomFormatHandler(file, &slog.HandlerOptions{
Level: logLevelMap[logLevel],
})
// 创建多输出处理器
multiHandler := &MultiHandler{
handlers: []slog.Handler{stdHandler, fileHandler},
}
// 设置全局日志处理器
slog.SetDefault(slog.New(multiHandler))
return nil
}
// InitDefaultLogger 初始化默认日志处理器
func InitDefaultLogger(level slog.Level) {
slog.SetDefault(slog.New(NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})))
}

197
pkg/utils/utils_test.go Normal file
View File

@ -0,0 +1,197 @@
package utils
import (
"testing"
)
func TestGenRandomNumber(t *testing.T) {
tests := []struct {
name string
length int
}{
{"Generate 1 digit", 1},
{"Generate 5 digits", 5},
{"Generate 9 digits", 9},
{"Generate 10 digits", 10},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GenRandomNumber(tt.length)
// 验证长度
if len(result) != tt.length {
t.Errorf("Expected length %d, got %d", tt.length, len(result))
}
// 验证所有字符都是数字
for i, c := range result {
if c < '0' || c > '9' {
t.Errorf("Character at position %d is not a digit: %c", i, c)
}
}
})
}
}
func TestGenRandomNumberUniqueness(t *testing.T) {
// 生成多个随机数,验证它们不完全相同(虽然理论上可能相同,但概率极低)
results := make(map[string]bool)
iterations := 100
length := 10
for i := 0; i < iterations; i++ {
result := GenRandomNumber(length)
results[result] = true
}
// 至少应该有一些不同的值不太可能100次都生成相同的10位数
if len(results) < 50 {
t.Errorf("Expected at least 50 unique values out of %d iterations, got %d", iterations, len(results))
}
}
func TestCreateSSRC(t *testing.T) {
tests := []struct {
name string
isLive bool
expected byte
}{
{"Live stream SSRC", true, '0'},
{"Non-live stream SSRC", false, '1'},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ssrc := CreateSSRC(tt.isLive)
// 验证长度为10
if len(ssrc) != 10 {
t.Errorf("Expected SSRC length 10, got %d", len(ssrc))
}
// 验证第一个字符
if ssrc[0] != tt.expected {
t.Errorf("Expected first character '%c', got '%c'", tt.expected, ssrc[0])
}
// 验证所有字符都是数字
for i, c := range ssrc {
if c < '0' || c > '9' {
t.Errorf("Character at position %d is not a digit: %c", i, c)
}
}
})
}
}
func TestCreateSSRCUniqueness(t *testing.T) {
// 测试生成的 SSRC 具有唯一性
results := make(map[string]bool)
iterations := 100
for i := 0; i < iterations; i++ {
ssrc := CreateSSRC(true)
results[ssrc] = true
}
// 应该有很多不同的值
if len(results) < 50 {
t.Errorf("Expected at least 50 unique SSRCs out of %d iterations, got %d", iterations, len(results))
}
}
func TestIsVideoChannel(t *testing.T) {
tests := []struct {
name string
channelID string
expected bool
}{
{
name: "Video channel type 131",
channelID: "34020000001310000001",
expected: true,
},
{
name: "Video channel type 132",
channelID: "34020000001320000001",
expected: true,
},
{
name: "Audio channel type 137",
channelID: "34020000001370000001",
expected: false,
},
{
name: "Alarm channel type 134",
channelID: "34020000001340000001",
expected: false,
},
{
name: "Other device type",
channelID: "34020000001110000001",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsVideoChannel(tt.channelID)
if result != tt.expected {
t.Errorf("IsVideoChannel(%s) = %v, expected %v", tt.channelID, result, tt.expected)
}
})
}
}
func TestGetSessionName(t *testing.T) {
tests := []struct {
name string
playType int
expected string
}{
{"Live play", 0, "Play"},
{"Playback", 1, "Playback"},
{"Download", 2, "Download"},
{"Talk", 3, "Talk"},
{"Unknown type", 99, "Play"},
{"Negative type", -1, "Play"},
{"Type 4", 4, "Play"},
{"Type 5", 5, "Play"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetSessionName(tt.playType)
if result != tt.expected {
t.Errorf("GetSessionName(%d) = %s, expected %s", tt.playType, result, tt.expected)
}
})
}
}
func TestGenRandomNumberZeroLength(t *testing.T) {
result := GenRandomNumber(0)
if len(result) != 0 {
t.Errorf("Expected empty string for length 0, got %s", result)
}
}
func TestCreateSSRCBothTypes(t *testing.T) {
// Test both live and non-live in one test
liveSSRC := CreateSSRC(true)
nonLiveSSRC := CreateSSRC(false)
if liveSSRC[0] != '0' {
t.Errorf("Live SSRC should start with '0', got '%c'", liveSSRC[0])
}
if nonLiveSSRC[0] != '1' {
t.Errorf("Non-live SSRC should start with '1', got '%c'", nonLiveSSRC[0])
}
// They should be different (with very high probability)
if liveSSRC == nonLiveSSRC {
t.Error("Live and non-live SSRCs should be different")
}
}

20
run/conf/config.yaml Normal file
View File

@ -0,0 +1,20 @@
# 通用配置
common:
# [debug, info, warn, error]
log-level: "info"
log-file: "logs/srs-sip.log"
# GB28181配置
gb28181:
serial: "34020000002000000001"
realm: "3402000000"
host: "0.0.0.0"
port: 5060
auth:
enable: false
password: "123456"
# HTTP服务配置
http:
listen: 8025
dir: ./html

60
run/srs/conf/srs.conf Normal file
View File

@ -0,0 +1,60 @@
listen 1935;
max_connections 1000;
# For docker, please use docker logs to manage the logs of SRS.
# See https://docs.docker.com/config/containers/logging/
srs_log_tank console;
daemon off;
disable_daemon_for_docker off;
http_api {
enabled on;
listen 1985;
raw_api {
enabled on;
allow_reload on;
}
}
http_server {
enabled on;
listen 8080;
dir ./objs/nginx/html;
}
stream_caster {
enabled on;
caster gb28181;
output rtmp://127.0.0.1/live/[stream];
listen 9000;
sip {
enabled off;
}
}
rtc_server {
enabled on;
listen 8000; # UDP port
# @see https://github.com/ossrs/srs/wiki/v4_CN_WebRTC#config-candidate
candidate $CANDIDATE;
# Disable for Oryx.
use_auto_detect_network_ip off;
api_as_candidates off;
}
vhost __defaultVhost__ {
http_remux {
enabled on;
mount [vhost]/[app]/[stream].flv;
}
rtc {
enabled on;
nack on;
twcc on;
stun_timeout 30;
dtls_role passive;
# @see https://github.com/ossrs/srs/wiki/v4_CN_WebRTC#rtmp-to-rtc
rtmp_to_rtc on;
keep_bframe off;
# @see https://github.com/ossrs/srs/wiki/v4_CN_WebRTC#rtc-to-rtmp
rtc_to_rtmp on;
pli_for_rtmp 6.0;
}
}

View File

@ -1,30 +1,30 @@
package main
import (
"context"
"os"
"os/signal"
"syscall"
"github.com/ossrs/go-oryx-lib/logger"
"github.com/ossrs/srs-bench/gb28181"
)
func main() {
ctx := context.Background()
var conf interface{}
conf = gb28181.Parse(ctx)
ctx, cancel := context.WithCancel(ctx)
go func() {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT)
for sig := range sigs {
logger.Wf(ctx, "Quit for signal %v", sig)
cancel()
}
}()
gb28181.Run(ctx, conf)
}
package main
import (
"context"
"os"
"os/signal"
"syscall"
"github.com/ossrs/go-oryx-lib/logger"
"github.com/ossrs/srs-bench/gb28181"
)
func main() {
ctx := context.Background()
var conf interface{}
conf = gb28181.Parse(ctx)
ctx, cancel := context.WithCancel(ctx)
go func() {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT)
for sig := range sigs {
logger.Wf(ctx, "Quit for signal %v", sig)
cancel()
}
}()
gb28181.Run(ctx, conf)
}