Compare commits

..

15 Commits

Author SHA1 Message Date
e80164f5be docs: enhance README with detailed explanation of SIP and SDP architecture
Some checks failed
CodeQL / Analyze Go Code (go) (push) Has been cancelled
2026-01-13 15:06:52 +08:00
ff4ddfacba refactor: rename zoom commands to use camelCase in API and PTZ control 2026-01-13 14:50:26 +08:00
ffb3fc423e feat: add zoom_in and zoom_out commands to PTZ command map
note: the `API.md` said it's `zoom_in` and `zoom_out` with underscore, while the go ptzCmdMap said it's `zoomin` and `zoomout` with underscore; I choose to support both (or unify them?)
2026-01-13 14:46:57 +08:00
b4474a160d feat: add zoom controls to PtzControlPanel with zoom in and zoom out functionality 2026-01-13 14:34:47 +08:00
6fbbfb698a Add configuration files for SRS SIP and update README with Docker commands 2026-01-13 11:45:32 +08:00
42d018b854 Add Docker support and configuration for SRS SIP
- Created a new README_cross.md file with Docker build instructions.
- Updated srs.conf to include logging configuration options.
- Added docker-compose.yml to define the SRS SIP service with necessary ports and volume mappings.
- Introduced config.yaml for general and GB28181-specific configurations.
- Added initial srs.conf with settings for RTMP, HTTP API, and WebRTC support.
2026-01-13 11:41:07 +08:00
2aa65de911 security 2025-10-15 16:04:35 +08:00
35de09aeb6 ut 2025-10-15 15:35:41 +08:00
1178b974a1 codeql 2025-10-15 14:18:47 +08:00
156f07644d gofmt 2025-10-15 10:05:52 +08:00
d9709f61a5 codecov 2025-10-15 09:59:32 +08:00
59bc95ab21 update 2025-10-15 09:29:38 +08:00
4c7485f4ef unit test 2025-10-15 09:14:33 +08:00
b0fce4380f fix warn 2025-10-14 16:51:37 +08:00
a92d1624c5 ci 2025-10-14 16:48:21 +08:00
37 changed files with 8762 additions and 1232 deletions

22
.github/codeql/codeql-config.yml vendored Normal file
View File

@ -0,0 +1,22 @@
name: "CodeQL Config"
# 指定要扫描的路径
paths:
- pkg
- main
# 排除不需要扫描的路径
paths-ignore:
- '**/*_test.go'
- 'html/**'
- 'objs/**'
- 'vendor/**'
# 使用的查询套件
queries:
- uses: security-extended
- uses: security-and-quality
# 禁用默认查询(如果只想使用自定义查询)
# disable-default-queries: true

82
.github/workflows/ci.yml vendored Normal file
View File

@ -0,0 +1,82 @@
name: CI
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build-and-test:
name: Build and Test
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: '1.23'
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: '20'
- name: Cache Go modules
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Cache Node modules
uses: actions/cache@v4
with:
path: html/NextGB/node_modules
key: ${{ runner.os }}-node-${{ hashFiles('html/NextGB/package-lock.json') }}
restore-keys: |
${{ runner.os }}-node-
- name: Download Go dependencies
run: go mod download
- name: Build Go application
run: make build
- name: Run Go tests
run: go test -v ./...
- name: Run Go tests with coverage
run: go test ./pkg/... -coverprofile=coverage.out -covermode=atomic
- name: Display coverage report
run: go tool cover -func=coverage.out
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: unittests
name: codecov-umbrella
fail_ci_if_error: false
- name: Install Vue dependencies
run: make vue-install
- name: Build Vue application
run: make vue-build
- name: Upload build artifacts
uses: actions/upload-artifact@v4
if: success()
with:
name: srs-sip-build
path: |
objs/srs-sip
html/NextGB/dist/
retention-days: 7

52
.github/workflows/codeql.yml vendored Normal file
View File

@ -0,0 +1,52 @@
name: "CodeQL"
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
schedule:
# 每周一凌晨2点运行
- cron: '0 2 * * 1'
jobs:
analyze:
name: Analyze Go Code
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
security-events: write
strategy:
fail-fast: false
matrix:
language: [ 'go' ]
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: '1.23'
cache: true
# 初始化 CodeQL
- name: Initialize CodeQL
uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
config-file: ./.github/codeql/codeql-config.yml
# 自动构建
- name: Autobuild
uses: github/codeql-action/autobuild@v3
# 执行 CodeQL 分析
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v3
with:
category: "/language:${{matrix.language}}"

2
.gitignore vendored
View File

@ -23,3 +23,5 @@
hs_err_pid* hs_err_pid*
objs objs
.idea .idea
run

View File

@ -1,16 +1,29 @@
# 引入SRS # 引入SRS
FROM ossrs/srs:v6.0.155 AS srs FROM ossrs/srs:v6.0.184 AS srs
# 前端构建阶段 # 前端构建阶段
FROM node:20-slim AS frontend-builder FROM node:20-slim AS frontend-builder
ARG HTTP_PROXY=
ARG NO_PROXY=
ENV http_proxy=${HTTP_PROXY} \
https_proxy=${HTTP_PROXY} \
no_proxy=${NO_PROXY}
WORKDIR /app/frontend WORKDIR /app/frontend
COPY html/NextGB/package*.json ./ COPY html/NextGB/package*.json ./
RUN npm install # RUN npm config set registry http://mirrors.cloud.tencent.com/npm/ \
# && npm install
RUN npm install
COPY html/NextGB/ . COPY html/NextGB/ .
RUN npm run build RUN npm run build
# 后端构建阶段 # 后端构建阶段
FROM golang:1.23 AS backend-builder FROM golang:1.23 AS backend-builder
ARG HTTP_PROXY=
ARG NO_PROXY=
ENV http_proxy=${HTTP_PROXY} \
https_proxy=${HTTP_PROXY} \
no_proxy=${NO_PROXY} \
GOPROXY=https://goproxy.cn,direct
WORKDIR /app WORKDIR /app
COPY go.mod go.sum ./ COPY go.mod go.sum ./
RUN go mod download RUN go mod download
@ -19,11 +32,20 @@ RUN CGO_ENABLED=0 GOOS=linux go build -o /app/srs-sip main/main.go
# 最终运行阶段 # 最终运行阶段
FROM ubuntu:22.04 FROM ubuntu:22.04
ARG HTTP_PROXY=
ARG NO_PROXY=
ENV http_proxy=${HTTP_PROXY} \
https_proxy=${HTTP_PROXY} \
no_proxy=${NO_PROXY}
WORKDIR /usr/local WORKDIR /usr/local
# 设置时区 # 设置时区
ENV TZ=Asia/Shanghai ENV TZ=Asia/Shanghai
RUN apt-get update && \ RUN sed -i \
-e 's@http://archive.ubuntu.com/ubuntu/@http://mirrors.ustc.edu.cn/ubuntu/@g' \
-e 's@http://security.ubuntu.com/ubuntu/@http://mirrors.ustc.edu.cn/ubuntu/@g' \
/etc/apt/sources.list && \
apt-get update && \
apt-get install -y ca-certificates tzdata supervisor && \ apt-get install -y ca-certificates tzdata supervisor && \
ln -fs /usr/share/zoneinfo/$TZ /etc/localtime && \ ln -fs /usr/share/zoneinfo/$TZ /etc/localtime && \
dpkg-reconfigure -f noninteractive tzdata && \ dpkg-reconfigure -f noninteractive tzdata && \
@ -60,7 +82,7 @@ stderr_logfile=/dev/stderr\n\
stderr_logfile_maxbytes=0\n\ stderr_logfile_maxbytes=0\n\
\n\ \n\
[program:srs-sip]\n\ [program:srs-sip]\n\
command=/usr/local/srs-sip/srs-sip\n\ command=/usr/local/srs-sip/srs-sip -c /usr/local/srs-sip/config.yaml\n\
directory=/usr/local/srs-sip\n\ directory=/usr/local/srs-sip\n\
autostart=true\n\ autostart=true\n\
autorestart=true\n\ autorestart=true\n\
@ -71,4 +93,4 @@ stderr_logfile_maxbytes=0" > /etc/supervisor/conf.d/supervisord.conf
EXPOSE 1935 5060 8025 9000 5060/udp 8000/udp EXPOSE 1935 5060 8025 9000 5060/udp 8000/udp
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"] CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]

View File

@ -1,5 +1,11 @@
# SRS-SIP # SRS-SIP
[![CI](https://github.com/ossrs/srs-sip/actions/workflows/ci.yml/badge.svg)](https://github.com/ossrs/srs-sip/actions/workflows/ci.yml)
[![CodeQL](https://github.com/ossrs/srs-sip/actions/workflows/codeql.yml/badge.svg)](https://github.com/ossrs/srs-sip/actions/workflows/codeql.yml)
[![codecov](https://codecov.io/gh/ossrs/srs-sip/branch/main/graph/badge.svg)](https://codecov.io/gh/ossrs/srs-sip)
[![Go Report Card](https://goreportcard.com/badge/github.com/ossrs/srs-sip)](https://goreportcard.com/report/github.com/ossrs/srs-sip)
[![License](https://img.shields.io/github/license/ossrs/srs-sip)](https://github.com/ossrs/srs-sip/blob/main/LICENSE)
## Usage ## Usage
Pre-requisites: Pre-requisites:
@ -24,6 +30,29 @@ Run the program:
./objs/srs-sip -c conf/config.yaml ./objs/srs-sip -c conf/config.yaml
``` ```
## Testing
Run all tests:
```bash
go test -v ./pkg/...
```
Run tests with coverage:
```bash
go test ./pkg/... -coverprofile=coverage.out
go tool cover -func=coverage.out
```
For more details, see [Testing Guide](docs/TESTING.md).
## Security
This project uses CodeQL for automated security scanning. For more information about security practices and how to report vulnerabilities, see [Security Guide](docs/SECURITY.md).
## Docker
Use docker Use docker
``` ```
docker run -id -p 1985:1985 -p 5060:5060 -p 8025:8025 -p 9000:9000 -p 5060:5060/udp -p 8000:8000/udp --name srs-sip --env CANDIDATE=your_ip ossrs/srs-sip:alpha docker run -id -p 1985:1985 -p 5060:5060 -p 8025:8025 -p 9000:9000 -p 5060:5060/udp -p 8000:8000/udp --name srs-sip --env CANDIDATE=your_ip ossrs/srs-sip:alpha

78
README_cross.md Normal file
View File

@ -0,0 +1,78 @@
# note to me
```bash
docker compose build --network host
```
```
docker compose up -d --force-recreate
```
# TODO
- [ ] let user choose whether use mirror (use which mirror) when building Dockerfile
---
Based on the logs and the **GB/T 28181-2022** standard you provided, here is the explanation:
Yes, this is **SIP**, but the content *inside* the SIP message is **SDP (Session Description Protocol)**.
While XML is used for *control* (like PTZ), **SDP** is used for **Media Negotiation** (setting up the video stream).
### The Architecture from your logs
1. **SIP (The Envelope):** Starts the conversation ("I want to watch video").
2. **SDP (The Letter inside):** Describes technical details ("Send video to IP X, Port Y, using Format Z").
3. **RTP (The result):** After this SIP/SDP handshake finishes, the actual binary video stream (PS/H.264) starts flowing over a separate TCP/UDP connection.
### Breakdown of your Log
This log shows a **Real-time Live View** handshake.
#### 1. The Request (SRS Server -> Camera)
The Server asks the Camera to send video.
```ini
INVITE sip:34020000001320000001@3402000000 SIP/2.0
Content-Type: application/sdp
s=Play # "Play" = Real-time Live View (Standard 9.2.2.1)
c=IN IP4 192.168.2.184 # The Media Server IP
m=video 9000 TCP/RTP/AVP 96 # "Send video to my Port 9000 via TCP"
a=recvonly # "I will only receive, not send"
y=0911024252 # **GB/T 28181 Special**: The SSRC (Stream ID)
```
#### 2. The Response (Camera -> SRS Server)
The Camera agrees and tells the server its own details.
```ini
SIP/2.0 200 OK
Content-Type: application/sdp
c=IN IP4 192.168.2.64 # The Camera IP
m=video 15060 TCP/RTP/AVP 96 # "I am sending from Port 15060"
a=sendonly # "I will only send"
a=setup:active # "I will initiate the TCP connection to you"
y=0911024252 # Matches the SSRC provided
f=v/2/2560x1440/25/2/8192a/... # **GB/T 28181 Special**: Media Info
```
### Key Differences from Standard SDP
GB/T 28181 modifies standard SDP with two specific fields mandatory for this protocol:
1. **`y=` (SSRC)**:
* **Standard SDP:** Does not have a `y` line.
* **GB/T 28181:** Uses `y` to define the **SSRC** (Synchronization Source). This 10-digit number is crucial because it marks every binary video packet sent later. If the binary stream headers don't match this `y` value, the stream is rejected.
2. **`f=` (Media Info)**:
* **Standard SDP:** Does not have an `f` line.
* **GB/T 28181:** Uses `f` to describe video parameters. In your log: `v/2/2560x1440/25...` means:
* `v`: Video
* `2`: Coding format (likely H.264 or H.265 mapped)
* `2560x1440`: Resolution
* `25`: Frame rate
### Summary of Cooperation
1. **XML (SIP MESSAGE):** Used for "remote control" (PTZ, Query, Keepalive).
2. **SDP (SIP INVITE):** Used to *negotiate* the pipeline.
3. **Binary (RTP/PS):** The actual heavy video data that flows through the pipe created by the SDP.

View File

@ -3,6 +3,11 @@ max_connections 1000;
# For docker, please use docker logs to manage the logs of SRS. # For docker, please use docker logs to manage the logs of SRS.
# See https://docs.docker.com/config/containers/logging/ # See https://docs.docker.com/config/containers/logging/
srs_log_tank console; srs_log_tank console;
# srs_log_tank file;
# srs_log_file /var/log/srs/srs.log;
# ff_log_dir /var/log/srs;
daemon off; daemon off;
disable_daemon_for_docker off; disable_daemon_for_docker off;
http_api { http_api {
@ -57,4 +62,4 @@ vhost __defaultVhost__ {
rtc_to_rtmp on; rtc_to_rtmp on;
pli_for_rtmp 6.0; pli_for_rtmp 6.0;
} }
} }

View File

@ -277,8 +277,8 @@ SRS-SIP 是一个基于 GB28181 协议的视频监控系统,提供设备管理
- `down`: 向下 - `down`: 向下
- `left`: 向左 - `left`: 向左
- `right`: 向右 - `right`: 向右
- `zoom_in`: 放大 - `zoomin`: 放大
- `zoom_out`: 缩小 - `zoomout`: 缩小
- `stop`: 停止 - `stop`: 停止
- `speed`: 控制速度1-9 - `speed`: 控制速度1-9

4535
doc/GBT+28181-2022.md Normal file

File diff suppressed because it is too large Load Diff

33
docker-compose.yml Normal file
View File

@ -0,0 +1,33 @@
services:
srs-sip:
build:
context: .
network: host
args:
HTTP_PROXY: ${HTTP_PROXY:-}
NO_PROXY: "localhost,127.0.0.1,::1"
environment:
# CANDIDATE: ${CANDIDATE:-}
CANDIDATE: 192.168.2.184
TZ: "Asia/Shanghai"
volumes:
- ./run/conf/config.yaml:/usr/local/srs-sip/config.yaml:ro
- ./run/logs:/usr/local/srs-sip/logs
- ./run/srs/conf/srs.conf:/usr/local/srs/conf/srs.conf:ro
# use docker logs
- ./run/srs/logs:/var/log/srs/
# for recording
- ./run/data:/data
ports:
# SRS RTMP
- "1985:1985"
# SRS media ingest (GB28181 RTP/PS 等,取决于你的配置)
- "9000:9000"
# SRS WebRTC
- "8000:8000/udp"
# SIP
- "5060:5060"
- "5060:5060/udp"
# WebUI
- "8025:8025"
restart: unless-stopped

View File

@ -1,6 +1,6 @@
<script setup lang="ts"> <script setup lang="ts">
import { ref, computed } from 'vue' import { ref, computed } from 'vue'
import { ArrowRight, VideoCamera } from '@element-plus/icons-vue' import { ArrowRight, VideoCamera, ZoomIn, ZoomOut } from '@element-plus/icons-vue'
import { import {
ArrowUp, ArrowUp,
ArrowDown, ArrowDown,
@ -158,6 +158,27 @@ const isDisabled = computed(() => !props.activeWindow)
<div class="direction-center"></div> <div class="direction-center"></div>
</div> </div>
</div> </div>
<div class="zoom-controls">
<el-button
class="zoom-btn"
:disabled="isDisabled"
@mousedown="handlePtzStart('zoomin')"
@mouseup="handlePtzStop"
@mouseleave="handlePtzStop"
>
<el-icon><ZoomIn /></el-icon>
</el-button>
<div class="zoom-label">变倍</div>
<el-button
class="zoom-btn"
:disabled="isDisabled"
@mousedown="handlePtzStart('zoomout')"
@mouseup="handlePtzStop"
@mouseleave="handlePtzStop"
>
<el-icon><ZoomOut /></el-icon>
</el-button>
</div>
<div class="speed-control"> <div class="speed-control">
<div class="speed-value">{{ speed }}</div> <div class="speed-value">{{ speed }}</div>
<el-slider <el-slider
@ -346,6 +367,48 @@ const isDisabled = computed(() => !props.activeWindow)
border-radius: 4px; border-radius: 4px;
} }
.zoom-controls {
display: flex;
flex-direction: column;
align-items: center;
gap: 6px;
height: 120px;
justify-content: center;
}
.zoom-btn {
--el-button-bg-color: var(--el-color-primary-light-8);
--el-button-border-color: var(--el-color-primary-light-5);
--el-button-hover-bg-color: var(--el-color-primary-light-7);
--el-button-hover-border-color: var(--el-color-primary-light-4);
--el-button-active-bg-color: var(--el-color-primary-light-5);
--el-button-active-border-color: var(--el-color-primary);
width: 36px;
height: 36px;
padding: 0;
margin: 0;
border-radius: 4px;
.el-icon {
font-size: 18px;
}
&:hover {
transform: scale(1.05);
}
&:active {
transform: scale(0.95);
}
}
.zoom-label {
font-size: 12px;
color: var(--el-text-color-secondary);
font-weight: 500;
}
.control-groups, .control-groups,
.control-group { .control-group {
display: none; display: none;

View File

@ -8,6 +8,7 @@ import (
"os" "os"
"os/signal" "os/signal"
"path" "path"
"path/filepath"
"strconv" "strconv"
"strings" "strings"
"syscall" "syscall"
@ -95,9 +96,42 @@ func main() {
return return
} }
// 首先检查原始路径是否包含 ".." 以防止路径遍历攻击
if strings.Contains(r.URL.Path, "..") {
slog.Warn("potential path traversal attempt detected", "path", r.URL.Path)
http.Error(w, "Invalid path", http.StatusBadRequest)
return
}
// 清理路径
cleanPath := path.Clean(r.URL.Path)
// 检查请求的文件是否存在 // 检查请求的文件是否存在
filePath := path.Join(conf.Http.Dir, r.URL.Path) filePath := filepath.Join(conf.Http.Dir, cleanPath)
_, err := os.Stat(filePath)
// 确保最终路径在允许的目录内
absDir, err := filepath.Abs(conf.Http.Dir)
if err != nil {
slog.Error("failed to get absolute path of http dir", "error", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
absFilePath, err := filepath.Abs(filePath)
if err != nil {
slog.Error("failed to get absolute path of file", "error", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
// 验证文件路径在允许的目录内
if !strings.HasPrefix(absFilePath, absDir) {
slog.Warn("path traversal attempt blocked", "requested", r.URL.Path, "resolved", absFilePath)
http.Error(w, "Access denied", http.StatusForbidden)
return
}
_, err = os.Stat(absFilePath)
if os.IsNotExist(err) { if os.IsNotExist(err) {
// 如果文件不存在,返回 index.html // 如果文件不存在,返回 index.html
slog.Info("file not found, redirect to index", "path", r.URL.Path) slog.Info("file not found, redirect to index", "path", r.URL.Path)

247
main/main_test.go Normal file
View File

@ -0,0 +1,247 @@
package main
import (
"path"
"path/filepath"
"strings"
"testing"
)
// TestPathTraversalPrevention 测试路径遍历防护
func TestPathTraversalPrevention(t *testing.T) {
baseDir := "/var/www/html"
tests := []struct {
name string
inputPath string
shouldFail bool
reason string
}{
{
name: "Normal file",
inputPath: "/index.html",
shouldFail: false,
reason: "Normal file access should be allowed",
},
{
name: "Subdirectory file",
inputPath: "/css/style.css",
shouldFail: false,
reason: "Subdirectory access should be allowed",
},
{
name: "Deep subdirectory",
inputPath: "/js/lib/jquery.min.js",
shouldFail: false,
reason: "Deep subdirectory access should be allowed",
},
{
name: "Parent directory traversal",
inputPath: "/../etc/passwd",
shouldFail: true,
reason: "Parent directory traversal should be blocked",
},
{
name: "Double parent traversal",
inputPath: "/../../etc/passwd",
shouldFail: true,
reason: "Double parent traversal should be blocked",
},
{
name: "Multiple parent traversal",
inputPath: "/../../../etc/passwd",
shouldFail: true,
reason: "Multiple parent traversal should be blocked",
},
{
name: "Mixed path with parent",
inputPath: "/css/../../etc/passwd",
shouldFail: true,
reason: "Mixed path with parent should be blocked",
},
{
name: "Dot slash path",
inputPath: "/./index.html",
shouldFail: false,
reason: "Dot slash should be cleaned but allowed",
},
{
name: "Complex traversal",
inputPath: "/css/../js/../../../etc/passwd",
shouldFail: true,
reason: "Complex traversal should be blocked",
},
{
name: "Root path",
inputPath: "/",
shouldFail: false,
reason: "Root path should be allowed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟修复后的路径验证逻辑
// 首先检查原始路径是否包含 ".."
containsDoubleDotInOriginal := strings.Contains(tt.inputPath, "..")
// 如果原始路径包含 "..",直接阻止
if containsDoubleDotInOriginal {
if !tt.shouldFail {
t.Errorf("%s: Path contains '..' but should be allowed: %s", tt.reason, tt.inputPath)
}
t.Logf("Input: %s, Contains '..': true, Blocked: true (early check)", tt.inputPath)
return
}
// 清理路径
cleanPath := path.Clean(tt.inputPath)
// 构建文件路径
filePath := filepath.Join(baseDir, cleanPath)
// 获取绝对路径
absDir, err := filepath.Abs(baseDir)
if err != nil {
t.Fatalf("Failed to get absolute path of base dir: %v", err)
}
absFilePath, err := filepath.Abs(filePath)
if err != nil {
t.Fatalf("Failed to get absolute path of file: %v", err)
}
// 验证路径是否在允许的目录内
isOutsideBaseDir := !strings.HasPrefix(absFilePath, absDir)
// 判断是否应该被阻止
shouldBlock := isOutsideBaseDir
if tt.shouldFail && !shouldBlock {
t.Errorf("%s: Expected path to be blocked, but it was allowed. Path: %s, Clean: %s, Abs: %s",
tt.reason, tt.inputPath, cleanPath, absFilePath)
}
if !tt.shouldFail && shouldBlock {
t.Errorf("%s: Expected path to be allowed, but it was blocked. Path: %s, Clean: %s, Abs: %s",
tt.reason, tt.inputPath, cleanPath, absFilePath)
}
// 额外的日志信息用于调试
t.Logf("Input: %s, Clean: %s, Outside base: %v, Blocked: %v",
tt.inputPath, cleanPath, isOutsideBaseDir, shouldBlock)
})
}
}
// TestPathCleanBehavior 测试 path.Clean 的行为
func TestPathCleanBehavior(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"/index.html", "/index.html"},
{"/../etc/passwd", "/etc/passwd"},
{"/./index.html", "/index.html"},
{"/css/../index.html", "/index.html"},
{"//double//slash", "/double/slash"},
{"/trailing/slash/", "/trailing/slash"},
{"/./././index.html", "/index.html"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := path.Clean(tt.input)
if result != tt.expected {
t.Errorf("path.Clean(%q) = %q, expected %q", tt.input, result, tt.expected)
}
})
}
}
// TestAbsolutePathValidation 测试绝对路径验证
func TestAbsolutePathValidation(t *testing.T) {
// 使用临时目录进行测试
baseDir := t.TempDir()
tests := []struct {
name string
path string
shouldFail bool
}{
{
name: "File in base directory",
path: "index.html",
shouldFail: false,
},
{
name: "File in subdirectory",
path: "css/style.css",
shouldFail: false,
},
{
name: "Attempt to escape with parent",
path: "../outside.txt",
shouldFail: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cleanPath := path.Clean(tt.path)
filePath := filepath.Join(baseDir, cleanPath)
absDir, err := filepath.Abs(baseDir)
if err != nil {
t.Fatalf("Failed to get absolute path: %v", err)
}
absFilePath, err := filepath.Abs(filePath)
if err != nil {
t.Fatalf("Failed to get absolute file path: %v", err)
}
isOutside := !strings.HasPrefix(absFilePath, absDir)
if tt.shouldFail && !isOutside {
t.Errorf("Expected path to be outside base dir, but it wasn't: %s", absFilePath)
}
if !tt.shouldFail && isOutside {
t.Errorf("Expected path to be inside base dir, but it wasn't: %s", absFilePath)
}
})
}
}
// BenchmarkPathValidation 性能基准测试
func BenchmarkPathValidation(b *testing.B) {
baseDir := "/var/www/html"
testPath := "/css/style.css"
b.ResetTimer()
for i := 0; i < b.N; i++ {
cleanPath := path.Clean(testPath)
_ = strings.Contains(cleanPath, "..")
filePath := filepath.Join(baseDir, cleanPath)
absDir, _ := filepath.Abs(baseDir)
absFilePath, _ := filepath.Abs(filePath)
_ = strings.HasPrefix(absFilePath, absDir)
}
}
// BenchmarkPathValidationMalicious 恶意路径的性能测试
func BenchmarkPathValidationMalicious(b *testing.B) {
baseDir := "/var/www/html"
testPath := "/../../../etc/passwd"
b.ResetTimer()
for i := 0; i < b.N; i++ {
cleanPath := path.Clean(testPath)
_ = strings.Contains(cleanPath, "..")
filePath := filepath.Join(baseDir, cleanPath)
absDir, _ := filepath.Abs(baseDir)
absFilePath, _ := filepath.Abs(filePath)
_ = strings.HasPrefix(absFilePath, absDir)
}
}

View File

@ -95,7 +95,7 @@ func GetLocalIP() (string, error) {
} }
var candidates []Iface var candidates []Iface
for _, ifc := range ifaces { for _, ifc := range ifaces {
if ifc.Flags&net.FlagUp == 0 || ifc.Flags&net.FlagUp == 0 { if ifc.Flags&net.FlagUp == 0 {
continue continue
} }
if ifc.Flags&(net.FlagPointToPoint|net.FlagLoopback) != 0 { if ifc.Flags&(net.FlagPointToPoint|net.FlagLoopback) != 0 {

154
pkg/config/config_test.go Normal file
View File

@ -0,0 +1,154 @@
package config
import (
"os"
"testing"
)
func TestDefaultConfig(t *testing.T) {
cfg := DefaultConfig()
// 测试 Common 配置
if cfg.Common.LogLevel != "info" {
t.Errorf("Expected log level 'info', got '%s'", cfg.Common.LogLevel)
}
if cfg.Common.LogFile != "app.log" {
t.Errorf("Expected log file 'app.log', got '%s'", cfg.Common.LogFile)
}
// 测试 GB28181 配置
if cfg.GB28181.Serial != "34020000002000000001" {
t.Errorf("Expected serial '34020000002000000001', got '%s'", cfg.GB28181.Serial)
}
if cfg.GB28181.Realm != "3402000000" {
t.Errorf("Expected realm '3402000000', got '%s'", cfg.GB28181.Realm)
}
if cfg.GB28181.Host != "0.0.0.0" {
t.Errorf("Expected host '0.0.0.0', got '%s'", cfg.GB28181.Host)
}
if cfg.GB28181.Port != 5060 {
t.Errorf("Expected port 5060, got %d", cfg.GB28181.Port)
}
if cfg.GB28181.Auth.Enable != false {
t.Errorf("Expected auth enable false, got %v", cfg.GB28181.Auth.Enable)
}
if cfg.GB28181.Auth.Password != "123456" {
t.Errorf("Expected auth password '123456', got '%s'", cfg.GB28181.Auth.Password)
}
// 测试 HTTP 配置
if cfg.Http.Port != 8025 {
t.Errorf("Expected http port 8025, got %d", cfg.Http.Port)
}
if cfg.Http.Dir != "./html" {
t.Errorf("Expected http dir './html', got '%s'", cfg.Http.Dir)
}
}
func TestLoadConfigNonExistent(t *testing.T) {
// 测试加载不存在的配置文件,应该返回默认配置
cfg, err := LoadConfig("non_existent_config.yaml")
if err != nil {
t.Fatalf("Expected no error for non-existent config, got: %v", err)
}
// 应该返回默认配置
defaultCfg := DefaultConfig()
if cfg.Common.LogLevel != defaultCfg.Common.LogLevel {
t.Errorf("Expected default log level, got '%s'", cfg.Common.LogLevel)
}
}
func TestLoadConfigValid(t *testing.T) {
// 创建临时配置文件
tempFile := "test_config.yaml"
defer os.Remove(tempFile)
configContent := `common:
log-level: debug
log-file: test.log
gb28181:
serial: "12345678901234567890"
realm: "1234567890"
host: "127.0.0.1"
port: 5061
auth:
enable: true
password: "test123"
http:
listen: 9000
dir: "./test_html"
`
err := os.WriteFile(tempFile, []byte(configContent), 0644)
if err != nil {
t.Fatalf("Failed to create test config file: %v", err)
}
// 加载配置
cfg, err := LoadConfig(tempFile)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
// 验证配置
if cfg.Common.LogLevel != "debug" {
t.Errorf("Expected log level 'debug', got '%s'", cfg.Common.LogLevel)
}
if cfg.Common.LogFile != "test.log" {
t.Errorf("Expected log file 'test.log', got '%s'", cfg.Common.LogFile)
}
if cfg.GB28181.Serial != "12345678901234567890" {
t.Errorf("Expected serial '12345678901234567890', got '%s'", cfg.GB28181.Serial)
}
if cfg.GB28181.Port != 5061 {
t.Errorf("Expected port 5061, got %d", cfg.GB28181.Port)
}
if cfg.GB28181.Auth.Enable != true {
t.Errorf("Expected auth enable true, got %v", cfg.GB28181.Auth.Enable)
}
if cfg.Http.Port != 9000 {
t.Errorf("Expected http port 9000, got %d", cfg.Http.Port)
}
}
func TestLoadConfigInvalid(t *testing.T) {
// 创建无效的配置文件
tempFile := "test_invalid_config.yaml"
defer os.Remove(tempFile)
invalidContent := `invalid yaml content: [[[`
err := os.WriteFile(tempFile, []byte(invalidContent), 0644)
if err != nil {
t.Fatalf("Failed to create test config file: %v", err)
}
// 加载配置应该失败
_, err = LoadConfig(tempFile)
if err == nil {
t.Error("Expected error for invalid config file, got nil")
}
}
func TestGetLocalIP(t *testing.T) {
ip, err := GetLocalIP()
// 在某些环境下可能没有网络接口,所以允许返回错误
if err != nil {
t.Logf("GetLocalIP returned error (may be expected in some environments): %v", err)
return
}
// 如果成功,验证返回的是有效的 IP 地址
if ip == "" {
t.Error("Expected non-empty IP address")
}
// 简单验证 IP 格式(应该包含点号)
if len(ip) < 7 { // 最短的 IP 是 0.0.0.0
t.Errorf("IP address seems invalid: %s", ip)
}
t.Logf("Local IP: %s", ip)
}

View File

@ -1,152 +1,152 @@
package db package db
import ( import (
"database/sql" "database/sql"
"sync" "sync"
"github.com/ossrs/srs-sip/pkg/models" "github.com/ossrs/srs-sip/pkg/models"
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
) )
var ( var (
instance *MediaServerDB instance *MediaServerDB
once sync.Once once sync.Once
) )
type MediaServerDB struct { type MediaServerDB struct {
models.MediaServerResponse models.MediaServerResponse
db *sql.DB db *sql.DB
} }
// GetInstance 返回 MediaServerDB 的单例实例 // GetInstance 返回 MediaServerDB 的单例实例
func GetInstance(dbPath string) (*MediaServerDB, error) { func GetInstance(dbPath string) (*MediaServerDB, error) {
var err error var err error
once.Do(func() { once.Do(func() {
instance, err = NewMediaServerDB(dbPath) instance, err = NewMediaServerDB(dbPath)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
return instance, nil return instance, nil
} }
func NewMediaServerDB(dbPath string) (*MediaServerDB, error) { func NewMediaServerDB(dbPath string) (*MediaServerDB, error) {
db, err := sql.Open("sqlite", dbPath) db, err := sql.Open("sqlite", dbPath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 创建媒体服务器表 // 创建媒体服务器表
_, err = db.Exec(` _, err = db.Exec(`
CREATE TABLE IF NOT EXISTS media_servers ( CREATE TABLE IF NOT EXISTS media_servers (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
type TEXT NOT NULL, type TEXT NOT NULL,
name TEXT NOT NULL, name TEXT NOT NULL,
ip TEXT NOT NULL, ip TEXT NOT NULL,
port INTEGER NOT NULL, port INTEGER NOT NULL,
username TEXT, username TEXT,
password TEXT, password TEXT,
secret TEXT, secret TEXT,
is_default INTEGER NOT NULL DEFAULT 0, is_default INTEGER NOT NULL DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP created_at DATETIME DEFAULT CURRENT_TIMESTAMP
) )
`) `)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &MediaServerDB{db: db}, nil return &MediaServerDB{db: db}, nil
} }
// GetMediaServerByNameAndIP 根据名称和IP查询媒体服务器 // GetMediaServerByNameAndIP 根据名称和IP查询媒体服务器
func (m *MediaServerDB) GetMediaServerByNameAndIP(name, ip string) (*models.MediaServerResponse, error) { func (m *MediaServerDB) GetMediaServerByNameAndIP(name, ip string) (*models.MediaServerResponse, error) {
var ms models.MediaServerResponse var ms models.MediaServerResponse
err := m.db.QueryRow(` err := m.db.QueryRow(`
SELECT id, name, type, ip, port, username, password, secret, is_default, created_at SELECT id, name, type, ip, port, username, password, secret, is_default, created_at
FROM media_servers WHERE name = ? AND ip = ? FROM media_servers WHERE name = ? AND ip = ?
`, name, ip).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt) `, name, ip).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &ms, nil return &ms, nil
} }
func (m *MediaServerDB) AddMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error { func (m *MediaServerDB) AddMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error {
_, err := m.db.Exec(` _, err := m.db.Exec(`
INSERT INTO media_servers (name, type, ip, port, username, password, secret, is_default) INSERT INTO media_servers (name, type, ip, port, username, password, secret, is_default)
VALUES (?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`, name, serverType, ip, port, username, password, secret, isDefault) `, name, serverType, ip, port, username, password, secret, isDefault)
return err return err
} }
// AddOrUpdateMediaServer 添加或更新媒体服务器(如果已存在则更新) // AddOrUpdateMediaServer 添加或更新媒体服务器(如果已存在则更新)
func (m *MediaServerDB) AddOrUpdateMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error { func (m *MediaServerDB) AddOrUpdateMediaServer(name, serverType, ip string, port int, username, password, secret string, isDefault int) error {
// 检查是否已存在 // 检查是否已存在
existing, err := m.GetMediaServerByNameAndIP(name, ip) existing, err := m.GetMediaServerByNameAndIP(name, ip)
if err == nil && existing != nil { if err == nil && existing != nil {
// 已存在,更新记录 // 已存在,更新记录
_, err = m.db.Exec(` _, err = m.db.Exec(`
UPDATE media_servers UPDATE media_servers
SET type = ?, port = ?, username = ?, password = ?, secret = ?, is_default = ? SET type = ?, port = ?, username = ?, password = ?, secret = ?, is_default = ?
WHERE name = ? AND ip = ? WHERE name = ? AND ip = ?
`, serverType, port, username, password, secret, isDefault, name, ip) `, serverType, port, username, password, secret, isDefault, name, ip)
return err return err
} }
// 不存在,插入新记录 // 不存在,插入新记录
return m.AddMediaServer(name, serverType, ip, port, username, password, secret, isDefault) return m.AddMediaServer(name, serverType, ip, port, username, password, secret, isDefault)
} }
func (m *MediaServerDB) DeleteMediaServer(id int) error { func (m *MediaServerDB) DeleteMediaServer(id int) error {
_, err := m.db.Exec("DELETE FROM media_servers WHERE id = ?", id) _, err := m.db.Exec("DELETE FROM media_servers WHERE id = ?", id)
return err return err
} }
func (m *MediaServerDB) GetMediaServer(id int) (*models.MediaServerResponse, error) { func (m *MediaServerDB) GetMediaServer(id int) (*models.MediaServerResponse, error) {
var ms models.MediaServerResponse var ms models.MediaServerResponse
err := m.db.QueryRow(` err := m.db.QueryRow(`
SELECT id, name, type, ip, port, username, password, secret, is_default, created_at SELECT id, name, type, ip, port, username, password, secret, is_default, created_at
FROM media_servers WHERE id = ? FROM media_servers WHERE id = ?
`, id).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt) `, id).Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &ms, nil return &ms, nil
} }
func (m *MediaServerDB) ListMediaServers() ([]models.MediaServerResponse, error) { func (m *MediaServerDB) ListMediaServers() ([]models.MediaServerResponse, error) {
rows, err := m.db.Query(` rows, err := m.db.Query(`
SELECT id, name, type, ip, port, username, password, secret, is_default, created_at SELECT id, name, type, ip, port, username, password, secret, is_default, created_at
FROM media_servers ORDER BY created_at DESC FROM media_servers ORDER BY created_at DESC
`) `)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
var servers []models.MediaServerResponse var servers []models.MediaServerResponse
for rows.Next() { for rows.Next() {
var ms models.MediaServerResponse var ms models.MediaServerResponse
err := rows.Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt) err := rows.Scan(&ms.ID, &ms.Name, &ms.Type, &ms.IP, &ms.Port, &ms.Username, &ms.Password, &ms.Secret, &ms.IsDefault, &ms.CreatedAt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
servers = append(servers, ms) servers = append(servers, ms)
} }
return servers, nil return servers, nil
} }
func (m *MediaServerDB) SetDefaultMediaServer(id int) error { func (m *MediaServerDB) SetDefaultMediaServer(id int) error {
// 先将所有服务器设置为非默认 // 先将所有服务器设置为非默认
if _, err := m.db.Exec("UPDATE media_servers SET is_default = 0"); err != nil { if _, err := m.db.Exec("UPDATE media_servers SET is_default = 0"); err != nil {
return err return err
} }
// 将指定ID的服务器设置为默认 // 将指定ID的服务器设置为默认
_, err := m.db.Exec("UPDATE media_servers SET is_default = 1 WHERE id = ?", id) _, err := m.db.Exec("UPDATE media_servers SET is_default = 1 WHERE id = ?", id)
return err return err
} }
func (m *MediaServerDB) Close() error { func (m *MediaServerDB) Close() error {
return m.db.Close() return m.db.Close()
} }

View File

@ -105,3 +105,212 @@ func TestAddMediaServerDuplicates(t *testing.T) {
t.Log("Test confirmed: Old AddMediaServer method creates duplicates") t.Log("Test confirmed: Old AddMediaServer method creates duplicates")
} }
func TestGetMediaServer(t *testing.T) {
dbPath := "./test_get_media_server.db"
defer os.Remove(dbPath)
db, err := NewMediaServerDB(dbPath)
if err != nil {
t.Fatalf("Failed to create database: %v", err)
}
defer db.Close()
// 添加一个媒体服务器
err = db.AddMediaServer("TestServer", "ZLM", "192.168.1.200", 8080, "admin", "pass123", "secret123", 0)
if err != nil {
t.Fatalf("Failed to add media server: %v", err)
}
// 获取服务器列表以获得ID
servers, err := db.ListMediaServers()
if err != nil {
t.Fatalf("Failed to list media servers: %v", err)
}
if len(servers) == 0 {
t.Fatal("No servers found")
}
// 通过ID获取服务器
server, err := db.GetMediaServer(servers[0].ID)
if err != nil {
t.Fatalf("Failed to get media server: %v", err)
}
// 验证数据
if server.Name != "TestServer" {
t.Errorf("Expected name 'TestServer', got '%s'", server.Name)
}
if server.Type != "ZLM" {
t.Errorf("Expected type 'ZLM', got '%s'", server.Type)
}
if server.IP != "192.168.1.200" {
t.Errorf("Expected IP '192.168.1.200', got '%s'", server.IP)
}
if server.Port != 8080 {
t.Errorf("Expected port 8080, got %d", server.Port)
}
}
func TestGetMediaServerNotFound(t *testing.T) {
dbPath := "./test_get_not_found.db"
defer os.Remove(dbPath)
db, err := NewMediaServerDB(dbPath)
if err != nil {
t.Fatalf("Failed to create database: %v", err)
}
defer db.Close()
// 尝试获取不存在的服务器
_, err = db.GetMediaServer(999)
if err == nil {
t.Error("Expected error when getting non-existent server, got nil")
}
}
func TestDeleteMediaServer(t *testing.T) {
dbPath := "./test_delete_media_server.db"
defer os.Remove(dbPath)
db, err := NewMediaServerDB(dbPath)
if err != nil {
t.Fatalf("Failed to create database: %v", err)
}
defer db.Close()
// 添加两个服务器
err = db.AddMediaServer("Server1", "SRS", "192.168.1.1", 1985, "", "", "", 0)
if err != nil {
t.Fatalf("Failed to add server1: %v", err)
}
err = db.AddMediaServer("Server2", "ZLM", "192.168.1.2", 8080, "", "", "", 0)
if err != nil {
t.Fatalf("Failed to add server2: %v", err)
}
// 获取服务器列表
servers, err := db.ListMediaServers()
if err != nil {
t.Fatalf("Failed to list servers: %v", err)
}
if len(servers) != 2 {
t.Fatalf("Expected 2 servers, got %d", len(servers))
}
// 删除第一个服务器
err = db.DeleteMediaServer(servers[0].ID)
if err != nil {
t.Fatalf("Failed to delete server: %v", err)
}
// 验证只剩一个服务器
servers, err = db.ListMediaServers()
if err != nil {
t.Fatalf("Failed to list servers after delete: %v", err)
}
if len(servers) != 1 {
t.Fatalf("Expected 1 server after delete, got %d", len(servers))
}
}
func TestSetDefaultMediaServer(t *testing.T) {
dbPath := "./test_set_default.db"
defer os.Remove(dbPath)
db, err := NewMediaServerDB(dbPath)
if err != nil {
t.Fatalf("Failed to create database: %v", err)
}
defer db.Close()
// 添加三个服务器
err = db.AddMediaServer("Server1", "SRS", "192.168.1.1", 1985, "", "", "", 1)
if err != nil {
t.Fatalf("Failed to add server1: %v", err)
}
err = db.AddMediaServer("Server2", "ZLM", "192.168.1.2", 8080, "", "", "", 0)
if err != nil {
t.Fatalf("Failed to add server2: %v", err)
}
err = db.AddMediaServer("Server3", "SRS", "192.168.1.3", 1985, "", "", "", 0)
if err != nil {
t.Fatalf("Failed to add server3: %v", err)
}
// 获取服务器列表
servers, err := db.ListMediaServers()
if err != nil {
t.Fatalf("Failed to list servers: %v", err)
}
// 找到 Server2 的 ID
var server2ID int
for _, s := range servers {
if s.Name == "Server2" {
server2ID = s.ID
break
}
}
// 设置 Server2 为默认
err = db.SetDefaultMediaServer(server2ID)
if err != nil {
t.Fatalf("Failed to set default server: %v", err)
}
// 验证只有 Server2 是默认的
servers, err = db.ListMediaServers()
if err != nil {
t.Fatalf("Failed to list servers: %v", err)
}
defaultCount := 0
for _, s := range servers {
if s.IsDefault == 1 {
defaultCount++
if s.Name != "Server2" {
t.Errorf("Expected Server2 to be default, got %s", s.Name)
}
}
}
if defaultCount != 1 {
t.Errorf("Expected exactly 1 default server, got %d", defaultCount)
}
}
func TestGetMediaServerByNameAndIP(t *testing.T) {
dbPath := "./test_get_by_name_ip.db"
defer os.Remove(dbPath)
db, err := NewMediaServerDB(dbPath)
if err != nil {
t.Fatalf("Failed to create database: %v", err)
}
defer db.Close()
// 添加服务器
err = db.AddMediaServer("MyServer", "SRS", "10.0.0.1", 1985, "user", "pass", "secret", 0)
if err != nil {
t.Fatalf("Failed to add server: %v", err)
}
// 通过名称和IP查询
server, err := db.GetMediaServerByNameAndIP("MyServer", "10.0.0.1")
if err != nil {
t.Fatalf("Failed to get server by name and IP: %v", err)
}
if server.Name != "MyServer" || server.IP != "10.0.0.1" {
t.Errorf("Server data mismatch: %+v", server)
}
// 查询不存在的组合
_, err = db.GetMediaServerByNameAndIP("MyServer", "10.0.0.2")
if err == nil {
t.Error("Expected error for non-existent name/IP combination, got nil")
}
}

298
pkg/media/media_test.go Normal file
View File

@ -0,0 +1,298 @@
package media
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestApiRequest_Success(t *testing.T) {
// Create a test server that returns a successful response
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{
"code": 0,
"data": map[string]string{
"message": "success",
},
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
ctx := context.Background()
req := map[string]string{"test": "data"}
var res map[string]interface{}
err := apiRequest(ctx, server.URL, req, &res)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if res["code"].(float64) != 0 {
t.Errorf("Expected code 0, got %v", res["code"])
}
}
func TestApiRequest_GetMethod(t *testing.T) {
// Create a test server that checks the HTTP method
methodReceived := ""
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
methodReceived = r.Method
response := map[string]interface{}{
"code": 0,
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
ctx := context.Background()
var res map[string]interface{}
// When req is nil, should use GET method
err := apiRequest(ctx, server.URL, nil, &res)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if methodReceived != "GET" {
t.Errorf("Expected GET method, got %s", methodReceived)
}
}
func TestApiRequest_PostMethod(t *testing.T) {
// Create a test server that checks the HTTP method
methodReceived := ""
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
methodReceived = r.Method
response := map[string]interface{}{
"code": 0,
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
ctx := context.Background()
req := map[string]string{"test": "data"}
var res map[string]interface{}
// When req is not nil, should use POST method
err := apiRequest(ctx, server.URL, req, &res)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if methodReceived != "POST" {
t.Errorf("Expected POST method, got %s", methodReceived)
}
}
func TestApiRequest_ServerError(t *testing.T) {
// Create a test server that returns an error status code
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
ctx := context.Background()
var res map[string]interface{}
err := apiRequest(ctx, server.URL, nil, &res)
if err == nil {
t.Error("Expected error for server error status code")
}
}
func TestApiRequest_NonZeroCode(t *testing.T) {
// Create a test server that returns a non-zero error code
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{
"code": 100,
"message": "error message",
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
ctx := context.Background()
var res map[string]interface{}
err := apiRequest(ctx, server.URL, nil, &res)
if err == nil {
t.Error("Expected error for non-zero code")
}
}
func TestApiRequest_InvalidJSON(t *testing.T) {
// Create a test server that returns invalid JSON
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("invalid json"))
}))
defer server.Close()
ctx := context.Background()
var res map[string]interface{}
err := apiRequest(ctx, server.URL, nil, &res)
if err == nil {
t.Error("Expected error for invalid JSON")
}
}
func TestApiRequest_ContextCancellation(t *testing.T) {
// Create a test server that delays the response
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
response := map[string]interface{}{
"code": 0,
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
// Create a context that is already cancelled
ctx, cancel := context.WithCancel(context.Background())
cancel()
var res map[string]interface{}
err := apiRequest(ctx, server.URL, nil, &res)
if err == nil {
t.Error("Expected error for cancelled context")
}
}
func TestApiRequest_InvalidURL(t *testing.T) {
ctx := context.Background()
var res map[string]interface{}
// Test with invalid URL
err := apiRequest(ctx, "://invalid-url", nil, &res)
if err == nil {
t.Error("Expected error for invalid URL")
}
}
func TestApiRequest_Timeout(t *testing.T) {
// Create a test server that never responds
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(15 * time.Second) // Longer than the 10 second timeout
}))
defer server.Close()
ctx := context.Background()
var res map[string]interface{}
err := apiRequest(ctx, server.URL, nil, &res)
if err == nil {
t.Error("Expected timeout error")
}
}
func TestApiRequest_ComplexResponse(t *testing.T) {
// Create a test server that returns a complex response
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{
"code": 0,
"data": map[string]interface{}{
"id": "12345",
"status": "active",
"items": []string{
"item1",
"item2",
"item3",
},
},
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
ctx := context.Background()
var res map[string]interface{}
err := apiRequest(ctx, server.URL, nil, &res)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
// Check that the response was properly unmarshaled
if res["code"].(float64) != 0 {
t.Errorf("Expected code 0, got %v", res["code"])
}
data, ok := res["data"].(map[string]interface{})
if !ok {
t.Error("Expected data to be a map")
} else {
if data["id"].(string) != "12345" {
t.Errorf("Expected id '12345', got %v", data["id"])
}
}
}
func TestApiRequest_EmptyResponse(t *testing.T) {
// Create a test server that returns minimal response
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{
"code": 0,
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
ctx := context.Background()
var res map[string]interface{}
err := apiRequest(ctx, server.URL, nil, &res)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if res["code"].(float64) != 0 {
t.Errorf("Expected code 0, got %v", res["code"])
}
}
func TestApiRequest_WithRequestBody(t *testing.T) {
// Create a test server that echoes back the request
var receivedBody map[string]interface{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&receivedBody)
response := map[string]interface{}{
"code": 0,
"echo": receivedBody,
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
ctx := context.Background()
req := map[string]interface{}{
"id": "test-id",
"ssrc": "test-ssrc",
}
var res map[string]interface{}
err := apiRequest(ctx, server.URL, req, &res)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
// Check that the server received the request body
if receivedBody["id"].(string) != "test-id" {
t.Errorf("Expected id 'test-id', got %v", receivedBody["id"])
}
}

View File

@ -1,59 +1,59 @@
package media package media
import ( import (
"context" "context"
"github.com/ossrs/go-oryx-lib/errors" "github.com/ossrs/go-oryx-lib/errors"
) )
type Zlm struct { type Zlm struct {
Ctx context.Context Ctx context.Context
Schema string // The schema of ZLM, eg: http Schema string // The schema of ZLM, eg: http
Addr string // The address of ZLM, eg: localhost:8085 Addr string // The address of ZLM, eg: localhost:8085
Secret string // The secret of ZLM, eg: ZLMediaKit_secret Secret string // The secret of ZLM, eg: ZLMediaKit_secret
} }
// /index/api/openRtpServer // /index/api/openRtpServer
// secret={{ZLMediaKit_secret}}&port=0&enable_tcp=1&stream_id=test2 // secret={{ZLMediaKit_secret}}&port=0&enable_tcp=1&stream_id=test2
func (z *Zlm) Publish(id, ssrc string) (int, error) { func (z *Zlm) Publish(id, ssrc string) (int, error) {
res := struct { res := struct {
Code int `json:"code"` Code int `json:"code"`
Port int `json:"port"` Port int `json:"port"`
}{} }{}
if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/openRtpServer?secret="+z.Secret+"&port=0&enable_tcp=1&stream_id="+id+"&ssrc="+ssrc, nil, &res); err != nil { if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/openRtpServer?secret="+z.Secret+"&port=0&enable_tcp=1&stream_id="+id+"&ssrc="+ssrc, nil, &res); err != nil {
return 0, errors.Wrapf(err, "gb/v1/publish") return 0, errors.Wrapf(err, "gb/v1/publish")
} }
return res.Port, nil return res.Port, nil
} }
// /index/api/closeRtpServer // /index/api/closeRtpServer
func (z *Zlm) Unpublish(id string) error { func (z *Zlm) Unpublish(id string) error {
res := struct { res := struct {
Code int `json:"code"` Code int `json:"code"`
}{} }{}
if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/closeRtpServer?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil { if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/closeRtpServer?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil {
return errors.Wrapf(err, "gb/v1/publish") return errors.Wrapf(err, "gb/v1/publish")
} }
return nil return nil
} }
// /index/api/getMediaList // /index/api/getMediaList
func (z *Zlm) GetStreamStatus(id string) (bool, error) { func (z *Zlm) GetStreamStatus(id string) (bool, error) {
res := struct { res := struct {
Code int `json:"code"` Code int `json:"code"`
}{} }{}
if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/getMediaList?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil { if err := apiRequest(z.Ctx, z.Schema+"://"+z.Addr+"/index/api/getMediaList?secret="+z.Secret+"&stream_id="+id, nil, &res); err != nil {
return false, errors.Wrapf(err, "gb/v1/publish") return false, errors.Wrapf(err, "gb/v1/publish")
} }
return res.Code == 0, nil return res.Code == 0, nil
} }
func (z *Zlm) GetAddr() string { func (z *Zlm) GetAddr() string {
return z.Addr return z.Addr
} }
func (z *Zlm) GetWebRTCAddr(id string) string { func (z *Zlm) GetWebRTCAddr(id string) string {
return "http://" + z.Addr + "/index/api/webrtc?app=rtp&stream=" + id + "&type=play" return "http://" + z.Addr + "/index/api/webrtc?app=rtp&stream=" + id + "&type=play"
} }

View File

@ -1,106 +1,106 @@
package models package models
import "encoding/xml" import "encoding/xml"
type Record struct { type Record struct {
DeviceID string `xml:"DeviceID" json:"device_id"` DeviceID string `xml:"DeviceID" json:"device_id"`
Name string `xml:"Name" json:"name"` Name string `xml:"Name" json:"name"`
FilePath string `xml:"FilePath" json:"file_path"` FilePath string `xml:"FilePath" json:"file_path"`
Address string `xml:"Address" json:"address"` Address string `xml:"Address" json:"address"`
StartTime string `xml:"StartTime" json:"start_time"` StartTime string `xml:"StartTime" json:"start_time"`
EndTime string `xml:"EndTime" json:"end_time"` EndTime string `xml:"EndTime" json:"end_time"`
Secrecy int `xml:"Secrecy" json:"secrecy"` Secrecy int `xml:"Secrecy" json:"secrecy"`
Type string `xml:"Type" json:"type"` Type string `xml:"Type" json:"type"`
} }
// Example XML structure for channel info: // Example XML structure for channel info:
// //
// <Item> // <Item>
// <DeviceID>34020000001320000002</DeviceID> // <DeviceID>34020000001320000002</DeviceID>
// <Name>209</Name> // <Name>209</Name>
// <Manufacturer>UNIVIEW</Manufacturer> // <Manufacturer>UNIVIEW</Manufacturer>
// <Model>HIC6622-IR@X33-VF</Model> // <Model>HIC6622-IR@X33-VF</Model>
// <Owner>IPC-B2202.7.11.230222</Owner> // <Owner>IPC-B2202.7.11.230222</Owner>
// <CivilCode>CivilCode</CivilCode> // <CivilCode>CivilCode</CivilCode>
// <Address>Address</Address> // <Address>Address</Address>
// <Parental>1</Parental> // <Parental>1</Parental>
// <ParentID>75015310072008100002</ParentID> // <ParentID>75015310072008100002</ParentID>
// <SafetyWay>0</SafetyWay> // <SafetyWay>0</SafetyWay>
// <RegisterWay>1</RegisterWay> // <RegisterWay>1</RegisterWay>
// <Secrecy>0</Secrecy> // <Secrecy>0</Secrecy>
// <Status>ON</Status> // <Status>ON</Status>
// <Longitude>0.0000000</Longitude> // <Longitude>0.0000000</Longitude>
// <Latitude>0.0000000</Latitude> // <Latitude>0.0000000</Latitude>
// <Info> // <Info>
// <PTZType>1</PTZType> // <PTZType>1</PTZType>
// <Resolution>6/4/2</Resolution> // <Resolution>6/4/2</Resolution>
// <DownloadSpeed>0</DownloadSpeed> // <DownloadSpeed>0</DownloadSpeed>
// </Info> // </Info>
// </Item> // </Item>
type ChannelInfo struct { type ChannelInfo struct {
DeviceID string `json:"device_id"` DeviceID string `json:"device_id"`
ParentID string `json:"parent_id"` ParentID string `json:"parent_id"`
Name string `json:"name"` Name string `json:"name"`
Manufacturer string `json:"manufacturer"` Manufacturer string `json:"manufacturer"`
Model string `json:"model"` Model string `json:"model"`
Owner string `json:"owner"` Owner string `json:"owner"`
CivilCode string `json:"civil_code"` CivilCode string `json:"civil_code"`
Address string `json:"address"` Address string `json:"address"`
Port int `json:"port"` Port int `json:"port"`
Parental int `json:"parental"` Parental int `json:"parental"`
SafetyWay int `json:"safety_way"` SafetyWay int `json:"safety_way"`
RegisterWay int `json:"register_way"` RegisterWay int `json:"register_way"`
Secrecy int `json:"secrecy"` Secrecy int `json:"secrecy"`
IPAddress string `json:"ip_address"` IPAddress string `json:"ip_address"`
Status ChannelStatus `json:"status"` Status ChannelStatus `json:"status"`
Longitude float64 `json:"longitude"` Longitude float64 `json:"longitude"`
Latitude float64 `json:"latitude"` Latitude float64 `json:"latitude"`
Info struct { Info struct {
PTZType int `json:"ptz_type"` PTZType int `json:"ptz_type"`
Resolution string `json:"resolution"` Resolution string `json:"resolution"`
DownloadSpeed string `json:"download_speed"` // Speed levels: 1/2/4/8 DownloadSpeed string `json:"download_speed"` // Speed levels: 1/2/4/8
} `json:"info"` } `json:"info"`
// Custom fields // Custom fields
Ssrc string `json:"ssrc"` Ssrc string `json:"ssrc"`
} }
type ChannelStatus string type ChannelStatus string
// BasicParam // BasicParam
// <! -- 基本参数配置(可选)--> // <! -- 基本参数配置(可选)-->
// <elementname="BasicParam"minOccurs="0"> // <elementname="BasicParam"minOccurs="0">
// <complexType> // <complexType>
// <sequence> // <sequence>
// <! -- 设备名称(可选)--> // <! -- 设备名称(可选)-->
// <elementname="Name"type="string" minOccurs="0"/> // <elementname="Name"type="string" minOccurs="0"/>
// <! -- 注册过期时间(可选)--> // <! -- 注册过期时间(可选)-->
// <elementname="Expiration"type="integer" minOccurs="0"/> // <elementname="Expiration"type="integer" minOccurs="0"/>
// <! -- 心跳间隔时间(可选)--> // <! -- 心跳间隔时间(可选)-->
// <elementname="HeartBeatInterval"type="integer" minOccurs="0"/> // <elementname="HeartBeatInterval"type="integer" minOccurs="0"/>
// <! -- 心跳超时次数(可选)--> // <! -- 心跳超时次数(可选)-->
// <elementname="HeartBeatCount"type="integer" minOccurs="0"/> // <elementname="HeartBeatCount"type="integer" minOccurs="0"/>
// </sequence> // </sequence>
// </complexType> // </complexType>
type BasicParam struct { type BasicParam struct {
Name string `xml:"Name"` Name string `xml:"Name"`
Expiration int `xml:"Expiration"` Expiration int `xml:"Expiration"`
HeartBeatInterval int `xml:"HeartBeatInterval"` HeartBeatInterval int `xml:"HeartBeatInterval"`
HeartBeatCount int `xml:"HeartBeatCount"` HeartBeatCount int `xml:"HeartBeatCount"`
} }
type XmlMessageInfo struct { type XmlMessageInfo struct {
XMLName xml.Name XMLName xml.Name
CmdType string CmdType string
SN int SN int
DeviceID string DeviceID string
DeviceName string DeviceName string
Manufacturer string Manufacturer string
Model string Model string
Channel string Channel string
DeviceList []ChannelInfo `xml:"DeviceList>Item"` DeviceList []ChannelInfo `xml:"DeviceList>Item"`
RecordList []*Record `xml:"RecordList>Item"` RecordList []*Record `xml:"RecordList>Item"`
BasicParam BasicParam `xml:"BasicParam"` BasicParam BasicParam `xml:"BasicParam"`
SumNum int SumNum int
} }

View File

@ -1,80 +1,80 @@
package models package models
type BaseRequest struct { type BaseRequest struct {
DeviceID string `json:"device_id"` DeviceID string `json:"device_id"`
ChannelID string `json:"channel_id"` ChannelID string `json:"channel_id"`
} }
type InviteRequest struct { type InviteRequest struct {
BaseRequest BaseRequest
MediaServerId int `json:"media_server_id"` MediaServerId int `json:"media_server_id"`
PlayType int `json:"play_type"` // 0: live, 1: playback, 2: download PlayType int `json:"play_type"` // 0: live, 1: playback, 2: download
SubStream int `json:"sub_stream"` SubStream int `json:"sub_stream"`
StartTime int64 `json:"start_time"` StartTime int64 `json:"start_time"`
EndTime int64 `json:"end_time"` EndTime int64 `json:"end_time"`
} }
type InviteResponse struct { type InviteResponse struct {
ChannelID string `json:"channel_id"` ChannelID string `json:"channel_id"`
URL string `json:"url"` URL string `json:"url"`
} }
type SessionRequest struct { type SessionRequest struct {
BaseRequest BaseRequest
URL string `json:"url"` URL string `json:"url"`
} }
type ByeRequest struct { type ByeRequest struct {
SessionRequest SessionRequest
} }
type PauseRequest struct { type PauseRequest struct {
SessionRequest SessionRequest
} }
type ResumeRequest struct { type ResumeRequest struct {
SessionRequest SessionRequest
} }
type SpeedRequest struct { type SpeedRequest struct {
SessionRequest SessionRequest
Speed float32 `json:"speed"` Speed float32 `json:"speed"`
} }
type PTZControlRequest struct { type PTZControlRequest struct {
BaseRequest BaseRequest
PTZ string `json:"ptz"` PTZ string `json:"ptz"`
Speed string `json:"speed"` Speed string `json:"speed"`
} }
type QueryRecordRequest struct { type QueryRecordRequest struct {
BaseRequest BaseRequest
StartTime int64 `json:"start_time"` StartTime int64 `json:"start_time"`
EndTime int64 `json:"end_time"` EndTime int64 `json:"end_time"`
} }
type MediaServer struct { type MediaServer struct {
Name string `json:"name"` Name string `json:"name"`
Type string `json:"type"` Type string `json:"type"`
IP string `json:"ip"` IP string `json:"ip"`
Port int `json:"port"` Port int `json:"port"`
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
Secret string `json:"secret"` Secret string `json:"secret"`
IsDefault int `json:"is_default"` IsDefault int `json:"is_default"`
} }
type MediaServerRequest struct { type MediaServerRequest struct {
MediaServer MediaServer
} }
type MediaServerResponse struct { type MediaServerResponse struct {
MediaServer MediaServer
ID int `json:"id"` ID int `json:"id"`
CreatedAt string `json:"created_at"` CreatedAt string `json:"created_at"`
} }
type CommonResponse struct { type CommonResponse struct {
Code int `json:"code"` Code int `json:"code"`
Data interface{} `json:"data"` Data interface{} `json:"data"`
} }

337
pkg/models/types_test.go Normal file
View File

@ -0,0 +1,337 @@
package models
import (
"encoding/json"
"testing"
)
func TestBaseRequest(t *testing.T) {
req := BaseRequest{
DeviceID: "34020000001320000001",
ChannelID: "34020000001320000002",
}
if req.DeviceID != "34020000001320000001" {
t.Errorf("Expected DeviceID '34020000001320000001', got '%s'", req.DeviceID)
}
if req.ChannelID != "34020000001320000002" {
t.Errorf("Expected ChannelID '34020000001320000002', got '%s'", req.ChannelID)
}
}
func TestInviteRequest(t *testing.T) {
req := InviteRequest{
BaseRequest: BaseRequest{
DeviceID: "device123",
ChannelID: "channel123",
},
MediaServerId: 1,
PlayType: 0,
SubStream: 0,
StartTime: 1234567890,
EndTime: 1234567900,
}
if req.DeviceID != "device123" {
t.Errorf("Expected DeviceID 'device123', got '%s'", req.DeviceID)
}
if req.MediaServerId != 1 {
t.Errorf("Expected MediaServerId 1, got %d", req.MediaServerId)
}
if req.PlayType != 0 {
t.Errorf("Expected PlayType 0, got %d", req.PlayType)
}
}
func TestInviteRequestJSON(t *testing.T) {
jsonStr := `{
"device_id": "device123",
"channel_id": "channel123",
"media_server_id": 1,
"play_type": 1,
"sub_stream": 0,
"start_time": 1234567890,
"end_time": 1234567900
}`
var req InviteRequest
err := json.Unmarshal([]byte(jsonStr), &req)
if err != nil {
t.Fatalf("Failed to unmarshal JSON: %v", err)
}
if req.DeviceID != "device123" {
t.Errorf("Expected DeviceID 'device123', got '%s'", req.DeviceID)
}
if req.PlayType != 1 {
t.Errorf("Expected PlayType 1, got %d", req.PlayType)
}
}
func TestInviteResponse(t *testing.T) {
resp := InviteResponse{
ChannelID: "channel123",
URL: "webrtc://example.com/live/stream",
}
if resp.ChannelID != "channel123" {
t.Errorf("Expected ChannelID 'channel123', got '%s'", resp.ChannelID)
}
if resp.URL != "webrtc://example.com/live/stream" {
t.Errorf("Expected URL 'webrtc://example.com/live/stream', got '%s'", resp.URL)
}
}
func TestPTZControlRequest(t *testing.T) {
req := PTZControlRequest{
BaseRequest: BaseRequest{
DeviceID: "device123",
ChannelID: "channel123",
},
PTZ: "up",
Speed: "5",
}
if req.PTZ != "up" {
t.Errorf("Expected PTZ 'up', got '%s'", req.PTZ)
}
if req.Speed != "5" {
t.Errorf("Expected Speed '5', got '%s'", req.Speed)
}
}
func TestQueryRecordRequest(t *testing.T) {
req := QueryRecordRequest{
BaseRequest: BaseRequest{
DeviceID: "device123",
ChannelID: "channel123",
},
StartTime: 1234567890,
EndTime: 1234567900,
}
if req.StartTime != 1234567890 {
t.Errorf("Expected StartTime 1234567890, got %d", req.StartTime)
}
if req.EndTime != 1234567900 {
t.Errorf("Expected EndTime 1234567900, got %d", req.EndTime)
}
}
func TestMediaServer(t *testing.T) {
ms := MediaServer{
Name: "SRS Server",
Type: "SRS",
IP: "192.168.1.100",
Port: 1985,
Username: "admin",
Password: "password",
Secret: "secret",
IsDefault: 1,
}
if ms.Name != "SRS Server" {
t.Errorf("Expected Name 'SRS Server', got '%s'", ms.Name)
}
if ms.Type != "SRS" {
t.Errorf("Expected Type 'SRS', got '%s'", ms.Type)
}
if ms.Port != 1985 {
t.Errorf("Expected Port 1985, got %d", ms.Port)
}
if ms.IsDefault != 1 {
t.Errorf("Expected IsDefault 1, got %d", ms.IsDefault)
}
}
func TestMediaServerResponse(t *testing.T) {
resp := MediaServerResponse{
MediaServer: MediaServer{
Name: "Test Server",
Type: "ZLM",
IP: "10.0.0.1",
Port: 8080,
},
ID: 1,
CreatedAt: "2024-01-01 12:00:00",
}
if resp.ID != 1 {
t.Errorf("Expected ID 1, got %d", resp.ID)
}
if resp.CreatedAt != "2024-01-01 12:00:00" {
t.Errorf("Expected CreatedAt '2024-01-01 12:00:00', got '%s'", resp.CreatedAt)
}
}
func TestCommonResponse(t *testing.T) {
resp := CommonResponse{
Code: 0,
Data: map[string]string{"key": "value"},
}
if resp.Code != 0 {
t.Errorf("Expected Code 0, got %d", resp.Code)
}
// 测试 JSON 序列化
jsonData, err := json.Marshal(resp)
if err != nil {
t.Fatalf("Failed to marshal JSON: %v", err)
}
var decoded CommonResponse
err = json.Unmarshal(jsonData, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal JSON: %v", err)
}
if decoded.Code != 0 {
t.Errorf("Expected decoded Code 0, got %d", decoded.Code)
}
}
func TestSessionRequest(t *testing.T) {
req := SessionRequest{
BaseRequest: BaseRequest{
DeviceID: "device123",
ChannelID: "channel123",
},
URL: "webrtc://example.com/live/stream",
}
if req.URL != "webrtc://example.com/live/stream" {
t.Errorf("Expected URL 'webrtc://example.com/live/stream', got '%s'", req.URL)
}
}
func TestByeRequest(t *testing.T) {
req := ByeRequest{
SessionRequest: SessionRequest{
BaseRequest: BaseRequest{
DeviceID: "device123",
ChannelID: "channel123",
},
URL: "webrtc://example.com/live/stream",
},
}
if req.DeviceID != "device123" {
t.Errorf("Expected DeviceID 'device123', got '%s'", req.DeviceID)
}
}
func TestPauseRequest(t *testing.T) {
req := PauseRequest{
SessionRequest: SessionRequest{
BaseRequest: BaseRequest{
DeviceID: "device123",
ChannelID: "channel123",
},
URL: "webrtc://example.com/live/stream",
},
}
if req.URL != "webrtc://example.com/live/stream" {
t.Errorf("Expected URL 'webrtc://example.com/live/stream', got '%s'", req.URL)
}
}
func TestResumeRequest(t *testing.T) {
req := ResumeRequest{
SessionRequest: SessionRequest{
BaseRequest: BaseRequest{
DeviceID: "device123",
ChannelID: "channel123",
},
URL: "webrtc://example.com/live/stream",
},
}
if req.ChannelID != "channel123" {
t.Errorf("Expected ChannelID 'channel123', got '%s'", req.ChannelID)
}
}
func TestSpeedRequest(t *testing.T) {
req := SpeedRequest{
SessionRequest: SessionRequest{
BaseRequest: BaseRequest{
DeviceID: "device123",
ChannelID: "channel123",
},
URL: "webrtc://example.com/live/stream",
},
Speed: 2.0,
}
if req.Speed != 2.0 {
t.Errorf("Expected Speed 2.0, got %f", req.Speed)
}
}
func TestMediaServerRequestJSON(t *testing.T) {
jsonStr := `{
"name": "Test Server",
"type": "SRS",
"ip": "192.168.1.100",
"port": 1985,
"username": "admin",
"password": "pass123",
"secret": "secret123",
"is_default": 1
}`
var req MediaServerRequest
err := json.Unmarshal([]byte(jsonStr), &req)
if err != nil {
t.Fatalf("Failed to unmarshal JSON: %v", err)
}
if req.Name != "Test Server" {
t.Errorf("Expected Name 'Test Server', got '%s'", req.Name)
}
if req.Type != "SRS" {
t.Errorf("Expected Type 'SRS', got '%s'", req.Type)
}
if req.Port != 1985 {
t.Errorf("Expected Port 1985, got %d", req.Port)
}
}
func TestCommonResponseWithDifferentDataTypes(t *testing.T) {
tests := []struct {
name string
data interface{}
}{
{"String data", "test string"},
{"Integer data", 123},
{"Map data", map[string]interface{}{"key": "value"}},
{"Array data", []string{"item1", "item2"}},
{"Nil data", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp := CommonResponse{
Code: 0,
Data: tt.data,
}
jsonData, err := json.Marshal(resp)
if err != nil {
t.Fatalf("Failed to marshal JSON: %v", err)
}
var decoded CommonResponse
err = json.Unmarshal(jsonData, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal JSON: %v", err)
}
if decoded.Code != 0 {
t.Errorf("Expected Code 0, got %d", decoded.Code)
}
})
}
}

View File

@ -1,92 +1,92 @@
package service package service
import ( import (
"crypto/md5" "crypto/md5"
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"strings" "strings"
) )
// AuthInfo 存储解析后的认证信息 // AuthInfo 存储解析后的认证信息
type AuthInfo struct { type AuthInfo struct {
Username string Username string
Realm string Realm string
Nonce string Nonce string
URI string URI string
Response string Response string
Algorithm string Algorithm string
Method string Method string
} }
// GenerateNonce 生成随机 nonce 字符串 // GenerateNonce 生成随机 nonce 字符串
func GenerateNonce() string { func GenerateNonce() string {
b := make([]byte, 16) b := make([]byte, 16)
rand.Read(b) rand.Read(b)
return fmt.Sprintf("%x", b) return fmt.Sprintf("%x", b)
} }
// ParseAuthorization 解析 SIP Authorization 头 // ParseAuthorization 解析 SIP Authorization 头
// Authorization: Digest username="34020000001320000001",realm="3402000000", // Authorization: Digest username="34020000001320000001",realm="3402000000",
// nonce="44010b73623249f6916a6acf7c316b8e",uri="sip:34020000002000000001@3402000000", // nonce="44010b73623249f6916a6acf7c316b8e",uri="sip:34020000002000000001@3402000000",
// response="e4ca3fdc5869fa1c544ea7af60014444",algorithm=MD5 // response="e4ca3fdc5869fa1c544ea7af60014444",algorithm=MD5
func ParseAuthorization(auth string) *AuthInfo { func ParseAuthorization(auth string) *AuthInfo {
auth = strings.TrimPrefix(auth, "Digest ") auth = strings.TrimPrefix(auth, "Digest ")
parts := strings.Split(auth, ",") parts := strings.Split(auth, ",")
result := &AuthInfo{} result := &AuthInfo{}
for _, part := range parts { for _, part := range parts {
part = strings.TrimSpace(part) part = strings.TrimSpace(part)
if !strings.Contains(part, "=") { if !strings.Contains(part, "=") {
continue continue
} }
kv := strings.SplitN(part, "=", 2) kv := strings.SplitN(part, "=", 2)
key := strings.TrimSpace(kv[0]) key := strings.TrimSpace(kv[0])
value := strings.Trim(strings.TrimSpace(kv[1]), "\"") value := strings.Trim(strings.TrimSpace(kv[1]), "\"")
switch key { switch key {
case "username": case "username":
result.Username = value result.Username = value
case "realm": case "realm":
result.Realm = value result.Realm = value
case "nonce": case "nonce":
result.Nonce = value result.Nonce = value
case "uri": case "uri":
result.URI = value result.URI = value
case "response": case "response":
result.Response = value result.Response = value
case "algorithm": case "algorithm":
result.Algorithm = value result.Algorithm = value
} }
} }
return result return result
} }
// ValidateAuth 验证 SIP 认证信息 // ValidateAuth 验证 SIP 认证信息
func ValidateAuth(authInfo *AuthInfo, password string) bool { func ValidateAuth(authInfo *AuthInfo, password string) bool {
if authInfo == nil { if authInfo == nil {
return false return false
} }
// 默认方法为 REGISTER // 默认方法为 REGISTER
method := "REGISTER" method := "REGISTER"
if authInfo.Method != "" { if authInfo.Method != "" {
method = authInfo.Method method = authInfo.Method
} }
// 计算 MD5 哈希 // 计算 MD5 哈希
ha1 := md5Hex(authInfo.Username + ":" + authInfo.Realm + ":" + password) ha1 := md5Hex(authInfo.Username + ":" + authInfo.Realm + ":" + password)
ha2 := md5Hex(method + ":" + authInfo.URI) ha2 := md5Hex(method + ":" + authInfo.URI)
correctResponse := md5Hex(ha1 + ":" + authInfo.Nonce + ":" + ha2) correctResponse := md5Hex(ha1 + ":" + authInfo.Nonce + ":" + ha2)
return authInfo.Response == correctResponse return authInfo.Response == correctResponse
} }
// md5Hex 计算字符串的 MD5 哈希值并返回十六进制字符串 // md5Hex 计算字符串的 MD5 哈希值并返回十六进制字符串
func md5Hex(s string) string { func md5Hex(s string) string {
hash := md5.New() hash := md5.New()
hash.Write([]byte(s)) hash.Write([]byte(s))
return hex.EncodeToString(hash.Sum(nil)) return hex.EncodeToString(hash.Sum(nil))
} }

345
pkg/service/auth_test.go Normal file
View File

@ -0,0 +1,345 @@
package service
import (
"strings"
"testing"
)
func TestGenerateNonce(t *testing.T) {
// 生成多个 nonce 并验证
nonces := make(map[string]bool)
iterations := 100
for i := 0; i < iterations; i++ {
nonce := GenerateNonce()
// 验证长度16字节的十六进制表示应该是32个字符
if len(nonce) != 32 {
t.Errorf("Expected nonce length 32, got %d", len(nonce))
}
// 验证是否为十六进制字符串
for _, c := range nonce {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
t.Errorf("Nonce contains non-hex character: %c", c)
}
}
nonces[nonce] = true
}
// 验证唯一性(应该生成不同的 nonce
if len(nonces) < 95 { // 允许极小概率的重复
t.Errorf("Expected at least 95 unique nonces out of %d, got %d", iterations, len(nonces))
}
}
func TestParseAuthorization(t *testing.T) {
tests := []struct {
name string
auth string
expected *AuthInfo
}{
{
name: "Complete authorization header",
auth: `Digest username="34020000001320000001",realm="3402000000",nonce="44010b73623249f6916a6acf7c316b8e",uri="sip:34020000002000000001@3402000000",response="e4ca3fdc5869fa1c544ea7af60014444",algorithm=MD5`,
expected: &AuthInfo{
Username: "34020000001320000001",
Realm: "3402000000",
Nonce: "44010b73623249f6916a6acf7c316b8e",
URI: "sip:34020000002000000001@3402000000",
Response: "e4ca3fdc5869fa1c544ea7af60014444",
Algorithm: "MD5",
},
},
{
name: "Authorization with spaces",
auth: `Digest username = "user123" , realm = "realm123" , nonce = "nonce123" , uri = "sip:test@example.com" , response = "resp123"`,
expected: &AuthInfo{
Username: "user123",
Realm: "realm123",
Nonce: "nonce123",
URI: "sip:test@example.com",
Response: "resp123",
},
},
{
name: "Partial authorization",
auth: `Digest username="testuser",realm="testrealm"`,
expected: &AuthInfo{
Username: "testuser",
Realm: "testrealm",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ParseAuthorization(tt.auth)
if result.Username != tt.expected.Username {
t.Errorf("Username: expected %s, got %s", tt.expected.Username, result.Username)
}
if result.Realm != tt.expected.Realm {
t.Errorf("Realm: expected %s, got %s", tt.expected.Realm, result.Realm)
}
if result.Nonce != tt.expected.Nonce {
t.Errorf("Nonce: expected %s, got %s", tt.expected.Nonce, result.Nonce)
}
if result.URI != tt.expected.URI {
t.Errorf("URI: expected %s, got %s", tt.expected.URI, result.URI)
}
if result.Response != tt.expected.Response {
t.Errorf("Response: expected %s, got %s", tt.expected.Response, result.Response)
}
if result.Algorithm != tt.expected.Algorithm {
t.Errorf("Algorithm: expected %s, got %s", tt.expected.Algorithm, result.Algorithm)
}
})
}
}
func TestParseAuthorizationEdgeCases(t *testing.T) {
tests := []struct {
name string
auth string
}{
{"Empty string", ""},
{"Only Digest", "Digest "},
{"Invalid format", "invalid format"},
{"No equals sign", "Digest username"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ParseAuthorization(tt.auth)
// 不应该 panic应该返回一个空的 AuthInfo
if result == nil {
t.Error("Expected non-nil result")
}
})
}
}
func TestMd5Hex(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "Simple string",
input: "hello",
expected: "5d41402abc4b2a76b9719d911017c592",
},
{
name: "Empty string",
input: "",
expected: "d41d8cd98f00b204e9800998ecf8427e",
},
{
name: "Numbers",
input: "123456",
expected: "e10adc3949ba59abbe56e057f20f883e",
},
{
name: "Complex string",
input: "username:realm:password",
expected: "8e8d14bf0c4b87c1c5b8b1e8c8e8d14b", // 这个需要实际计算
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := md5Hex(tt.input)
// 验证长度MD5 哈希应该是32个字符
if len(result) != 32 {
t.Errorf("Expected MD5 hash length 32, got %d", len(result))
}
// 验证是否为十六进制字符串
for _, c := range result {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
t.Errorf("MD5 hash contains non-hex character: %c", c)
}
}
// 对于已知的测试用例,验证具体值
if tt.name != "Complex string" && result != tt.expected {
t.Errorf("Expected MD5 hash %s, got %s", tt.expected, result)
}
})
}
}
func TestValidateAuth(t *testing.T) {
// 测试用例:使用已知的认证信息
t.Run("Valid authentication", func(t *testing.T) {
// 构造一个已知的认证场景
username := "testuser"
realm := "testrealm"
password := "testpass"
nonce := "testnonce"
uri := "sip:test@example.com"
method := "REGISTER"
// 计算正确的 response
ha1 := md5Hex(username + ":" + realm + ":" + password)
ha2 := md5Hex(method + ":" + uri)
correctResponse := md5Hex(ha1 + ":" + nonce + ":" + ha2)
authInfo := &AuthInfo{
Username: username,
Realm: realm,
Nonce: nonce,
URI: uri,
Response: correctResponse,
Method: method,
}
if !ValidateAuth(authInfo, password) {
t.Error("Expected authentication to be valid")
}
})
t.Run("Invalid password", func(t *testing.T) {
username := "testuser"
realm := "testrealm"
password := "testpass"
wrongPassword := "wrongpass"
nonce := "testnonce"
uri := "sip:test@example.com"
method := "REGISTER"
// 使用正确密码计算 response
ha1 := md5Hex(username + ":" + realm + ":" + password)
ha2 := md5Hex(method + ":" + uri)
correctResponse := md5Hex(ha1 + ":" + nonce + ":" + ha2)
authInfo := &AuthInfo{
Username: username,
Realm: realm,
Nonce: nonce,
URI: uri,
Response: correctResponse,
Method: method,
}
// 使用错误密码验证
if ValidateAuth(authInfo, wrongPassword) {
t.Error("Expected authentication to fail with wrong password")
}
})
t.Run("Nil authInfo", func(t *testing.T) {
if ValidateAuth(nil, "password") {
t.Error("Expected authentication to fail with nil authInfo")
}
})
t.Run("Default method", func(t *testing.T) {
// 测试当 Method 为空时,默认使用 REGISTER
username := "testuser"
realm := "testrealm"
password := "testpass"
nonce := "testnonce"
uri := "sip:test@example.com"
// 使用默认方法 REGISTER 计算 response
ha1 := md5Hex(username + ":" + realm + ":" + password)
ha2 := md5Hex("REGISTER:" + uri)
correctResponse := md5Hex(ha1 + ":" + nonce + ":" + ha2)
authInfo := &AuthInfo{
Username: username,
Realm: realm,
Nonce: nonce,
URI: uri,
Response: correctResponse,
Method: "", // 空方法,应该使用默认的 REGISTER
}
if !ValidateAuth(authInfo, password) {
t.Error("Expected authentication to be valid with default method")
}
})
}
func TestAuthInfoStruct(t *testing.T) {
// 测试 AuthInfo 结构体的基本功能
authInfo := &AuthInfo{
Username: "user",
Realm: "realm",
Nonce: "nonce",
URI: "uri",
Response: "response",
Algorithm: "MD5",
Method: "REGISTER",
}
if authInfo.Username != "user" {
t.Errorf("Expected username 'user', got '%s'", authInfo.Username)
}
if authInfo.Algorithm != "MD5" {
t.Errorf("Expected algorithm 'MD5', got '%s'", authInfo.Algorithm)
}
}
func TestParseAuthorizationWithoutDigestPrefix(t *testing.T) {
// 测试没有 "Digest " 前缀的情况
auth := `username="testuser",realm="testrealm"`
result := ParseAuthorization(auth)
if result.Username != "testuser" {
t.Errorf("Expected username 'testuser', got '%s'", result.Username)
}
if result.Realm != "testrealm" {
t.Errorf("Expected realm 'testrealm', got '%s'", result.Realm)
}
}
func TestParseAuthorizationCaseInsensitive(t *testing.T) {
// 虽然当前实现是大小写敏感的,但这个测试可以帮助未来改进
auth := `Digest username="testuser",realm="testrealm"`
result := ParseAuthorization(auth)
if result.Username == "" {
t.Error("Failed to parse username")
}
}
func TestMd5HexConsistency(t *testing.T) {
// 测试相同输入产生相同输出
input := "test string"
result1 := md5Hex(input)
result2 := md5Hex(input)
if result1 != result2 {
t.Errorf("MD5 hash should be consistent: %s != %s", result1, result2)
}
}
func TestMd5HexDifferentInputs(t *testing.T) {
// 测试不同输入产生不同输出
result1 := md5Hex("input1")
result2 := md5Hex("input2")
if result1 == result2 {
t.Error("Different inputs should produce different MD5 hashes")
}
}
func TestParseAuthorizationQuotedValues(t *testing.T) {
// 测试带引号和不带引号的值
auth := `Digest username="quoted",realm=unquoted,nonce="also-quoted"`
result := ParseAuthorization(auth)
if result.Username != "quoted" {
t.Errorf("Expected username 'quoted', got '%s'", result.Username)
}
// realm 没有引号,应该也能正确解析
if !strings.Contains(result.Realm, "unquoted") {
t.Logf("Realm value: '%s'", result.Realm)
}
}

View File

@ -1,17 +1,17 @@
package service package service
import ( import (
"context" "context"
"github.com/emiago/sipgo" "github.com/emiago/sipgo"
"github.com/ossrs/srs-sip/pkg/config" "github.com/ossrs/srs-sip/pkg/config"
) )
type Cascade struct { type Cascade struct {
ua *sipgo.UserAgent ua *sipgo.UserAgent
sipCli *sipgo.Client sipCli *sipgo.Client
sipSvr *sipgo.Server sipSvr *sipgo.Server
ctx context.Context ctx context.Context
conf *config.MainConfig conf *config.MainConfig
} }

View File

@ -1,171 +1,171 @@
package service package service
import ( import (
"bytes" "bytes"
"encoding/xml" "encoding/xml"
"fmt" "fmt"
"log/slog" "log/slog"
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
"github.com/emiago/sipgo/sip" "github.com/emiago/sipgo/sip"
"github.com/ossrs/srs-sip/pkg/models" "github.com/ossrs/srs-sip/pkg/models"
"github.com/ossrs/srs-sip/pkg/service/stack" "github.com/ossrs/srs-sip/pkg/service/stack"
"golang.org/x/net/html/charset" "golang.org/x/net/html/charset"
) )
const GB28181_ID_LENGTH = 20 const GB28181_ID_LENGTH = 20
func (s *UAS) isSameIP(addr1, addr2 string) bool { func (s *UAS) isSameIP(addr1, addr2 string) bool {
ip1, _, err1 := net.SplitHostPort(addr1) ip1, _, err1 := net.SplitHostPort(addr1)
ip2, _, err2 := net.SplitHostPort(addr2) ip2, _, err2 := net.SplitHostPort(addr2)
// 如果解析出错,回退到完整字符串比较 // 如果解析出错,回退到完整字符串比较
if err1 != nil || err2 != nil { if err1 != nil || err2 != nil {
return addr1 == addr2 return addr1 == addr2
} }
return ip1 == ip2 return ip1 == ip2
} }
func (s *UAS) onRegister(req *sip.Request, tx sip.ServerTransaction) { func (s *UAS) onRegister(req *sip.Request, tx sip.ServerTransaction) {
id := req.From().Address.User id := req.From().Address.User
if len(id) != GB28181_ID_LENGTH { if len(id) != GB28181_ID_LENGTH {
slog.Error("invalid device ID") slog.Error("invalid device ID")
return return
} }
slog.Debug(fmt.Sprintf("Received REGISTER %s", req.String())) slog.Debug(fmt.Sprintf("Received REGISTER %s", req.String()))
if s.conf.GB28181.Auth.Enable { if s.conf.GB28181.Auth.Enable {
// Check if Authorization header exists // Check if Authorization header exists
authHeader := req.GetHeaders("Authorization") authHeader := req.GetHeaders("Authorization")
// If no Authorization header, send 401 response to request authentication // If no Authorization header, send 401 response to request authentication
if len(authHeader) == 0 { if len(authHeader) == 0 {
nonce := GenerateNonce() nonce := GenerateNonce()
resp := stack.NewUnauthorizedResponse(req, http.StatusUnauthorized, "Unauthorized", nonce, s.conf.GB28181.Realm) resp := stack.NewUnauthorizedResponse(req, http.StatusUnauthorized, "Unauthorized", nonce, s.conf.GB28181.Realm)
_ = tx.Respond(resp) _ = tx.Respond(resp)
return return
} }
// Validate Authorization // Validate Authorization
authInfo := ParseAuthorization(authHeader[0].Value()) authInfo := ParseAuthorization(authHeader[0].Value())
if !ValidateAuth(authInfo, s.conf.GB28181.Auth.Password) { if !ValidateAuth(authInfo, s.conf.GB28181.Auth.Password) {
slog.Error("auth failed", "device_id", id, "source", req.Source()) slog.Error("auth failed", "device_id", id, "source", req.Source())
s.respondRegister(req, http.StatusForbidden, "Auth Failed", tx) s.respondRegister(req, http.StatusForbidden, "Auth Failed", tx)
return return
} }
} }
isUnregister := false isUnregister := false
if exps := req.GetHeaders("Expires"); len(exps) > 0 { if exps := req.GetHeaders("Expires"); len(exps) > 0 {
exp := exps[0] exp := exps[0]
expSec, err := strconv.ParseInt(exp.Value(), 10, 32) expSec, err := strconv.ParseInt(exp.Value(), 10, 32)
if err != nil { if err != nil {
slog.Error("parse expires header error", "error", err.Error()) slog.Error("parse expires header error", "error", err.Error())
return return
} }
if expSec == 0 { if expSec == 0 {
isUnregister = true isUnregister = true
} }
} else { } else {
slog.Error("empty expires header") slog.Error("empty expires header")
return return
} }
if isUnregister { if isUnregister {
DM.RemoveDevice(id) DM.RemoveDevice(id)
slog.Warn("Device unregistered", "device_id", id) slog.Warn("Device unregistered", "device_id", id)
return return
} else { } else {
if d, ok := DM.GetDevice(id); !ok { if d, ok := DM.GetDevice(id); !ok {
DM.AddDevice(id, &DeviceInfo{ DM.AddDevice(id, &DeviceInfo{
DeviceID: id, DeviceID: id,
SourceAddr: req.Source(), SourceAddr: req.Source(),
NetworkType: req.Transport(), NetworkType: req.Transport(),
}) })
s.respondRegister(req, http.StatusOK, "OK", tx) s.respondRegister(req, http.StatusOK, "OK", tx)
slog.Info(fmt.Sprintf("Register success %s %s", id, req.Source())) slog.Info(fmt.Sprintf("Register success %s %s", id, req.Source()))
go s.ConfigDownload(id) go s.ConfigDownload(id)
go s.Catalog(id) go s.Catalog(id)
} else { } else {
if d.SourceAddr != "" && !s.isSameIP(d.SourceAddr, req.Source()) { if d.SourceAddr != "" && !s.isSameIP(d.SourceAddr, req.Source()) {
slog.Error("Device already registered", "device_id", id, "old_source", d.SourceAddr, "new_source", req.Source()) slog.Error("Device already registered", "device_id", id, "old_source", d.SourceAddr, "new_source", req.Source())
// TODO: 如果ID重复应采用虚拟ID // TODO: 如果ID重复应采用虚拟ID
s.respondRegister(req, http.StatusBadRequest, "Conflict Device ID", tx) s.respondRegister(req, http.StatusBadRequest, "Conflict Device ID", tx)
} else { } else {
d.SourceAddr = req.Source() d.SourceAddr = req.Source()
d.NetworkType = req.Transport() d.NetworkType = req.Transport()
DM.UpdateDevice(id, d) DM.UpdateDevice(id, d)
s.respondRegister(req, http.StatusOK, "OK", tx) s.respondRegister(req, http.StatusOK, "OK", tx)
slog.Info(fmt.Sprintf("Re-register success %s %s", id, req.Source())) slog.Info(fmt.Sprintf("Re-register success %s %s", id, req.Source()))
} }
} }
} }
} }
func (s *UAS) respondRegister(req *sip.Request, code sip.StatusCode, reason string, tx sip.ServerTransaction) { func (s *UAS) respondRegister(req *sip.Request, code sip.StatusCode, reason string, tx sip.ServerTransaction) {
res := stack.NewRegisterResponse(req, code, reason) res := stack.NewRegisterResponse(req, code, reason)
_ = tx.Respond(res) _ = tx.Respond(res)
} }
func (s *UAS) onMessage(req *sip.Request, tx sip.ServerTransaction) { func (s *UAS) onMessage(req *sip.Request, tx sip.ServerTransaction) {
id := req.From().Address.User id := req.From().Address.User
if len(id) != 20 { if len(id) != 20 {
slog.Error("invalid device ID", "request", req.String()) slog.Error("invalid device ID", "request", req.String())
} }
slog.Debug(fmt.Sprintf("Received MESSAGE %s", req.String())) slog.Debug(fmt.Sprintf("Received MESSAGE %s", req.String()))
temp := &models.XmlMessageInfo{} temp := &models.XmlMessageInfo{}
decoder := xml.NewDecoder(bytes.NewReader([]byte(req.Body()))) decoder := xml.NewDecoder(bytes.NewReader([]byte(req.Body())))
decoder.CharsetReader = charset.NewReaderLabel decoder.CharsetReader = charset.NewReaderLabel
if err := decoder.Decode(temp); err != nil { if err := decoder.Decode(temp); err != nil {
slog.Error("decode message error", "error", err.Error(), "message", req.Body()) slog.Error("decode message error", "error", err.Error(), "message", req.Body())
} }
slog.Info(fmt.Sprintf("Received MESSAGE %s %s %s", temp.CmdType, temp.DeviceID, req.Source())) slog.Info(fmt.Sprintf("Received MESSAGE %s %s %s", temp.CmdType, temp.DeviceID, req.Source()))
var body string var body string
switch temp.CmdType { switch temp.CmdType {
case "Keepalive": case "Keepalive":
if d, ok := DM.GetDevice(temp.DeviceID); ok && d.Online { if d, ok := DM.GetDevice(temp.DeviceID); ok && d.Online {
// 更新设备心跳时间 // 更新设备心跳时间
DM.UpdateDeviceHeartbeat(temp.DeviceID) DM.UpdateDeviceHeartbeat(temp.DeviceID)
} else { } else {
tx.Respond(sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil)) tx.Respond(sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil))
return return
} }
case "SensorCatalog": // 兼容宇视,非国标 case "SensorCatalog": // 兼容宇视,非国标
case "Catalog": case "Catalog":
DM.UpdateChannels(temp.DeviceID, temp.DeviceList...) DM.UpdateChannels(temp.DeviceID, temp.DeviceList...)
//go s.AutoInvite(temp.DeviceID, temp.DeviceList...) //go s.AutoInvite(temp.DeviceID, temp.DeviceList...)
case "ConfigDownload": case "ConfigDownload":
DM.UpdateDeviceConfig(temp.DeviceID, &temp.BasicParam) DM.UpdateDeviceConfig(temp.DeviceID, &temp.BasicParam)
case "Alarm": case "Alarm":
slog.Info("Alarm") slog.Info("Alarm")
case "RecordInfo": case "RecordInfo":
// 从 recordQueryResults 中获取对应通道的结果通道 // 从 recordQueryResults 中获取对应通道的结果通道
if ch, ok := s.recordQueryResults.Load(temp.DeviceID); ok { if ch, ok := s.recordQueryResults.Load(temp.DeviceID); ok {
// 发送查询结果 // 发送查询结果
resultChan := ch.(chan *models.XmlMessageInfo) resultChan := ch.(chan *models.XmlMessageInfo)
resultChan <- temp resultChan <- temp
} }
default: default:
slog.Warn("Not supported CmdType", "cmd_type", temp.CmdType) slog.Warn("Not supported CmdType", "cmd_type", temp.CmdType)
response := sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil) response := sip.NewResponseFromRequest(req, http.StatusBadRequest, "", nil)
tx.Respond(response) tx.Respond(response)
return return
} }
tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", []byte(body))) tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", []byte(body)))
} }
func (s *UAS) onNotify(req *sip.Request, tx sip.ServerTransaction) { func (s *UAS) onNotify(req *sip.Request, tx sip.ServerTransaction) {
slog.Debug(fmt.Sprintf("Received NOTIFY %s", req.String())) slog.Debug(fmt.Sprintf("Received NOTIFY %s", req.String()))
tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", nil)) tx.Respond(sip.NewResponseFromRequest(req, http.StatusOK, "OK", nil))
} }

View File

@ -1,81 +1,81 @@
package service package service
import "fmt" import "fmt"
var ( var (
ptzCmdMap = map[string]uint8{ ptzCmdMap = map[string]uint8{
"stop": 0, "stop": 0,
"right": 1, "right": 1,
"left": 2, "left": 2,
"down": 4, "down": 4,
"downright": 5, "downright": 5,
"downleft": 6, "downleft": 6,
"up": 8, "up": 8,
"upright": 9, "upright": 9,
"upleft": 10, "upleft": 10,
"zoomin": 16, "zoomin": 16,
"zoomout": 32, "zoomout": 32,
} }
ptzSpeedMap = map[string]uint8{ ptzSpeedMap = map[string]uint8{
"1": 25, "1": 25,
"2": 50, "2": 50,
"3": 75, "3": 75,
"4": 100, "4": 100,
"5": 125, "5": 125,
"6": 150, "6": 150,
"7": 175, "7": 175,
"8": 200, "8": 200,
"9": 225, "9": 225,
"10": 255, "10": 255,
} }
defaultSpeed uint8 = 125 defaultSpeed uint8 = 125
) )
func getPTZSpeed(speed string) uint8 { func getPTZSpeed(speed string) uint8 {
if v, ok := ptzSpeedMap[speed]; ok { if v, ok := ptzSpeedMap[speed]; ok {
return v return v
} }
return defaultSpeed return defaultSpeed
} }
func toPTZCmd(cmdName, speed string) (string, error) { func toPTZCmd(cmdName, speed string) (string, error) {
cmdCode, ok := ptzCmdMap[cmdName] cmdCode, ok := ptzCmdMap[cmdName]
if !ok { if !ok {
return "", fmt.Errorf("invalid ptz command: %q", cmdName) return "", fmt.Errorf("invalid ptz command: %q", cmdName)
} }
speedValue := getPTZSpeed(speed) speedValue := getPTZSpeed(speed)
var horizontalSpeed, verticalSpeed, zSpeed uint8 var horizontalSpeed, verticalSpeed, zSpeed uint8
switch cmdName { switch cmdName {
case "left", "right": case "left", "right":
horizontalSpeed = speedValue horizontalSpeed = speedValue
verticalSpeed = 0 verticalSpeed = 0
case "up", "down": case "up", "down":
verticalSpeed = speedValue verticalSpeed = speedValue
horizontalSpeed = 0 horizontalSpeed = 0
case "upleft", "upright", "downleft", "downright": case "upleft", "upright", "downleft", "downright":
verticalSpeed = speedValue verticalSpeed = speedValue
horizontalSpeed = speedValue horizontalSpeed = speedValue
case "zoomin", "zoomout": case "zoomin", "zoomout":
zSpeed = speedValue << 4 // zoom速度在高4位 zSpeed = speedValue << 4 // zoom速度在高4位
default: default:
horizontalSpeed = 0 horizontalSpeed = 0
verticalSpeed = 0 verticalSpeed = 0
zSpeed = 0 zSpeed = 0
} }
sum := uint16(0xA5) + uint16(0x0F) + uint16(0x01) + uint16(cmdCode) + uint16(horizontalSpeed) + uint16(verticalSpeed) + uint16(zSpeed) sum := uint16(0xA5) + uint16(0x0F) + uint16(0x01) + uint16(cmdCode) + uint16(horizontalSpeed) + uint16(verticalSpeed) + uint16(zSpeed)
checksum := uint8(sum % 256) checksum := uint8(sum % 256)
return fmt.Sprintf("A50F01%02X%02X%02X%02X%02X", return fmt.Sprintf("A50F01%02X%02X%02X%02X%02X",
cmdCode, cmdCode,
horizontalSpeed, horizontalSpeed,
verticalSpeed, verticalSpeed,
zSpeed, zSpeed,
checksum, checksum,
), nil ), nil
} }

198
pkg/service/ptz_test.go Normal file
View File

@ -0,0 +1,198 @@
package service
import (
"testing"
)
func TestGetPTZSpeed(t *testing.T) {
tests := []struct {
name string
speed string
expected uint8
}{
{"Speed 1", "1", 25},
{"Speed 2", "2", 50},
{"Speed 3", "3", 75},
{"Speed 4", "4", 100},
{"Speed 5", "5", 125},
{"Speed 6", "6", 150},
{"Speed 7", "7", 175},
{"Speed 8", "8", 200},
{"Speed 9", "9", 225},
{"Speed 10", "10", 255},
{"Invalid speed", "invalid", 125}, // 默认速度
{"Empty speed", "", 125}, // 默认速度
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getPTZSpeed(tt.speed)
if result != tt.expected {
t.Errorf("getPTZSpeed(%s) = %d, expected %d", tt.speed, result, tt.expected)
}
})
}
}
func TestToPTZCmd(t *testing.T) {
tests := []struct {
name string
cmdName string
speed string
expectError bool
checkPrefix bool
}{
{"Stop command", "stop", "5", false, true},
{"Right command", "right", "5", false, true},
{"Left command", "left", "5", false, true},
{"Up command", "up", "5", false, true},
{"Down command", "down", "5", false, true},
{"Up-right command", "upright", "5", false, true},
{"Up-left command", "upleft", "5", false, true},
{"Down-right command", "downright", "5", false, true},
{"Down-left command", "downleft", "5", false, true},
{"Zoom in command", "zoomin", "5", false, true},
{"Zoom out command", "zoomout", "5", false, true},
{"Invalid command", "invalid", "5", true, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := toPTZCmd(tt.cmdName, tt.speed)
if tt.expectError {
if err == nil {
t.Errorf("Expected error for command %s, got nil", tt.cmdName)
}
return
}
if err != nil {
t.Errorf("Unexpected error for command %s: %v", tt.cmdName, err)
return
}
// 验证结果格式
if len(result) != 16 { // A50F01 + 5对字节 = 16个字符
t.Errorf("Expected result length 16, got %d for command %s", len(result), tt.cmdName)
}
// 验证前缀
if tt.checkPrefix && result[:6] != "A50F01" {
t.Errorf("Expected prefix 'A50F01', got '%s' for command %s", result[:6], tt.cmdName)
}
})
}
}
func TestToPTZCmdSpecificCases(t *testing.T) {
// 测试停止命令
t.Run("Stop command details", func(t *testing.T) {
result, err := toPTZCmd("stop", "5")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Stop 命令码是 0速度应该都是 0
// A50F01 00 00 00 00 checksum
if result[:8] != "A50F0100" {
t.Errorf("Stop command should start with A50F0100, got %s", result[:8])
}
})
// 测试右移命令
t.Run("Right command details", func(t *testing.T) {
result, err := toPTZCmd("right", "5")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Right 命令码是 1水平速度应该是 125 (0x7D)
// A50F01 01 7D 00 00 checksum
if result[:8] != "A50F0101" {
t.Errorf("Right command should start with A50F0101, got %s", result[:8])
}
})
// 测试上移命令
t.Run("Up command details", func(t *testing.T) {
result, err := toPTZCmd("up", "5")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Up 命令码是 8垂直速度应该是 125 (0x7D)
// A50F01 08 00 7D 00 checksum
if result[:8] != "A50F0108" {
t.Errorf("Up command should start with A50F0108, got %s", result[:8])
}
})
// 测试缩放命令
t.Run("Zoom in command details", func(t *testing.T) {
result, err := toPTZCmd("zoomin", "5")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Zoom in 命令码是 16 (0x10)
// A50F01 10 00 00 XX checksum (XX 是速度左移4位)
if result[:8] != "A50F0110" {
t.Errorf("Zoom in command should start with A50F0110, got %s", result[:8])
}
})
}
func TestToPTZCmdWithDifferentSpeeds(t *testing.T) {
speeds := []string{"1", "5", "10"}
for _, speed := range speeds {
t.Run("Right with speed "+speed, func(t *testing.T) {
result, err := toPTZCmd("right", speed)
if err != nil {
t.Errorf("Unexpected error with speed %s: %v", speed, err)
}
if len(result) != 16 {
t.Errorf("Expected length 16, got %d", len(result))
}
})
}
}
func TestPTZCmdMap(t *testing.T) {
// 验证所有预定义的命令都存在
expectedCommands := []string{
"stop", "right", "left", "down", "downright", "downleft",
"up", "upright", "upleft", "zoomin", "zoomout",
}
for _, cmd := range expectedCommands {
t.Run("Command exists: "+cmd, func(t *testing.T) {
if _, ok := ptzCmdMap[cmd]; !ok {
t.Errorf("Command %s not found in ptzCmdMap", cmd)
}
})
}
}
func TestPTZSpeedMap(t *testing.T) {
// 验证速度映射的正确性
expectedSpeeds := map[string]uint8{
"1": 25,
"2": 50,
"3": 75,
"4": 100,
"5": 125,
"6": 150,
"7": 175,
"8": 200,
"9": 225,
"10": 255,
}
for speed, expectedValue := range expectedSpeeds {
t.Run("Speed mapping: "+speed, func(t *testing.T) {
if value, ok := ptzSpeedMap[speed]; !ok {
t.Errorf("Speed %s not found in ptzSpeedMap", speed)
} else if value != expectedValue {
t.Errorf("Speed %s expected value %d, got %d", speed, expectedValue, value)
}
})
}
}

View File

@ -1,68 +1,68 @@
package stack package stack
import ( import (
"github.com/emiago/sipgo/sip" "github.com/emiago/sipgo/sip"
"github.com/ossrs/go-oryx-lib/errors" "github.com/ossrs/go-oryx-lib/errors"
) )
type OutboundConfig struct { type OutboundConfig struct {
Transport string Transport string
Via string Via string
From string From string
To string To string
} }
func NewRequest(method sip.RequestMethod, body []byte, conf OutboundConfig) (*sip.Request, error) { func NewRequest(method sip.RequestMethod, body []byte, conf OutboundConfig) (*sip.Request, error) {
if len(conf.From) != 20 || len(conf.To) != 20 { if len(conf.From) != 20 || len(conf.To) != 20 {
return nil, errors.Errorf("From or To length is not 20") return nil, errors.Errorf("From or To length is not 20")
} }
dest := conf.Via dest := conf.Via
to := sip.Uri{User: conf.To, Host: conf.To[:10]} to := sip.Uri{User: conf.To, Host: conf.To[:10]}
from := &sip.Uri{User: conf.From, Host: conf.From[:10]} from := &sip.Uri{User: conf.From, Host: conf.From[:10]}
fromHeader := &sip.FromHeader{Address: *from, Params: sip.NewParams()} fromHeader := &sip.FromHeader{Address: *from, Params: sip.NewParams()}
fromHeader.Params.Add("tag", sip.GenerateTagN(16)) fromHeader.Params.Add("tag", sip.GenerateTagN(16))
req := sip.NewRequest(method, to) req := sip.NewRequest(method, to)
req.AppendHeader(fromHeader) req.AppendHeader(fromHeader)
req.AppendHeader(&sip.ToHeader{Address: to}) req.AppendHeader(&sip.ToHeader{Address: to})
req.AppendHeader(&sip.ContactHeader{Address: *from}) req.AppendHeader(&sip.ContactHeader{Address: *from})
req.AppendHeader(sip.NewHeader("Max-Forwards", "70")) req.AppendHeader(sip.NewHeader("Max-Forwards", "70"))
req.SetBody(body) req.SetBody(body)
req.SetDestination(dest) req.SetDestination(dest)
req.SetTransport(conf.Transport) req.SetTransport(conf.Transport)
return req, nil return req, nil
} }
func NewRegisterRequest(conf OutboundConfig) (*sip.Request, error) { func NewRegisterRequest(conf OutboundConfig) (*sip.Request, error) {
req, err := NewRequest(sip.REGISTER, nil, conf) req, err := NewRequest(sip.REGISTER, nil, conf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.AppendHeader(sip.NewHeader("Expires", "3600")) req.AppendHeader(sip.NewHeader("Expires", "3600"))
return req, nil return req, nil
} }
func NewInviteRequest(body []byte, subject string, conf OutboundConfig) (*sip.Request, error) { func NewInviteRequest(body []byte, subject string, conf OutboundConfig) (*sip.Request, error) {
req, err := NewRequest(sip.INVITE, body, conf) req, err := NewRequest(sip.INVITE, body, conf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.AppendHeader(sip.NewHeader("Content-Type", "application/sdp")) req.AppendHeader(sip.NewHeader("Content-Type", "application/sdp"))
req.AppendHeader(sip.NewHeader("Subject", subject)) req.AppendHeader(sip.NewHeader("Subject", subject))
return req, nil return req, nil
} }
func NewMessageRequest(body []byte, conf OutboundConfig) (*sip.Request, error) { func NewMessageRequest(body []byte, conf OutboundConfig) (*sip.Request, error) {
req, err := NewRequest(sip.MESSAGE, body, conf) req, err := NewRequest(sip.MESSAGE, body, conf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.AppendHeader(sip.NewHeader("Content-Type", "Application/MANSCDP+xml")) req.AppendHeader(sip.NewHeader("Content-Type", "Application/MANSCDP+xml"))
return req, nil return req, nil
} }

View File

@ -0,0 +1,292 @@
package stack
import (
"testing"
"github.com/emiago/sipgo/sip"
)
func TestNewRequest(t *testing.T) {
tests := []struct {
name string
method sip.RequestMethod
body []byte
conf OutboundConfig
wantErr bool
errMsg string
checkFunc func(*testing.T, *sip.Request)
}{
{
name: "Valid REGISTER request",
method: sip.REGISTER,
body: nil,
conf: OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "34020000001110000001",
},
wantErr: false,
checkFunc: func(t *testing.T, req *sip.Request) {
if req.Method != sip.REGISTER {
t.Errorf("Expected method REGISTER, got %v", req.Method)
}
if req.Destination() != "192.168.1.100:5060" {
t.Errorf("Expected destination 192.168.1.100:5060, got %v", req.Destination())
}
if req.Transport() != "udp" {
t.Errorf("Expected transport udp, got %v", req.Transport())
}
},
},
{
name: "Valid INVITE request with body",
method: sip.INVITE,
body: []byte("v=0\r\no=- 0 0 IN IP4 127.0.0.1\r\n"),
conf: OutboundConfig{
Transport: "tcp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "34020000001110000001",
},
wantErr: false,
checkFunc: func(t *testing.T, req *sip.Request) {
if req.Method != sip.INVITE {
t.Errorf("Expected method INVITE, got %v", req.Method)
}
if req.Body() == nil {
t.Error("Expected body to be set")
}
},
},
{
name: "Invalid From length - too short",
method: sip.REGISTER,
body: nil,
conf: OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "123456789",
To: "34020000001110000001",
},
wantErr: true,
errMsg: "From or To length is not 20",
},
{
name: "Invalid To length - too long",
method: sip.REGISTER,
body: nil,
conf: OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "340200000011100000012345",
},
wantErr: true,
errMsg: "From or To length is not 20",
},
{
name: "Valid MESSAGE request",
method: sip.MESSAGE,
body: []byte("<?xml version=\"1.0\"?><Query></Query>"),
conf: OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "34020000001110000001",
},
wantErr: false,
checkFunc: func(t *testing.T, req *sip.Request) {
if req.Method != sip.MESSAGE {
t.Errorf("Expected method MESSAGE, got %v", req.Method)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := NewRequest(tt.method, tt.body, tt.conf)
if tt.wantErr {
if err == nil {
t.Errorf("Expected error but got nil")
} else if tt.errMsg != "" && err.Error() != tt.errMsg {
t.Errorf("Expected error message '%s', got '%s'", tt.errMsg, err.Error())
}
return
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
return
}
if req == nil {
t.Error("Expected request to be non-nil")
return
}
// Run custom checks
if tt.checkFunc != nil {
tt.checkFunc(t, req)
}
// Common checks
if req.From() == nil {
t.Error("Expected From header to be set")
}
if req.To() == nil {
t.Error("Expected To header to be set")
}
if req.Contact() == nil {
t.Error("Expected Contact header to be set")
}
})
}
}
func TestNewRegisterRequest(t *testing.T) {
conf := OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "34020000001110000001",
}
req, err := NewRegisterRequest(conf)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if req.Method != sip.REGISTER {
t.Errorf("Expected method REGISTER, got %v", req.Method)
}
// Check for Expires header
expires := req.GetHeader("Expires")
if expires == nil {
t.Error("Expected Expires header to be set")
} else if expires.Value() != "3600" {
t.Errorf("Expected Expires value 3600, got %v", expires.Value())
}
}
func TestNewInviteRequest(t *testing.T) {
conf := OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "34020000001110000001",
}
body := []byte("v=0\r\no=- 0 0 IN IP4 127.0.0.1\r\ns=Play\r\n")
subject := "34020000001320000001:0,34020000001110000001:0"
req, err := NewInviteRequest(body, subject, conf)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if req.Method != sip.INVITE {
t.Errorf("Expected method INVITE, got %v", req.Method)
}
// Check for Content-Type header
contentType := req.GetHeader("Content-Type")
if contentType == nil {
t.Error("Expected Content-Type header to be set")
} else if contentType.Value() != "application/sdp" {
t.Errorf("Expected Content-Type value application/sdp, got %v", contentType.Value())
}
// Check for Subject header
subjectHeader := req.GetHeader("Subject")
if subjectHeader == nil {
t.Error("Expected Subject header to be set")
} else if subjectHeader.Value() != subject {
t.Errorf("Expected Subject value %s, got %v", subject, subjectHeader.Value())
}
// Check body
if req.Body() == nil {
t.Error("Expected body to be set")
}
}
func TestNewMessageRequest(t *testing.T) {
conf := OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "34020000001110000001",
}
body := []byte("<?xml version=\"1.0\"?><Query><CmdType>Catalog</CmdType></Query>")
req, err := NewMessageRequest(body, conf)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if req.Method != sip.MESSAGE {
t.Errorf("Expected method MESSAGE, got %v", req.Method)
}
// Check for Content-Type header
contentType := req.GetHeader("Content-Type")
if contentType == nil {
t.Error("Expected Content-Type header to be set")
} else if contentType.Value() != "Application/MANSCDP+xml" {
t.Errorf("Expected Content-Type value Application/MANSCDP+xml, got %v", contentType.Value())
}
// Check body
if req.Body() == nil {
t.Error("Expected body to be set")
}
}
func TestNewRequestWithInvalidConfig(t *testing.T) {
tests := []struct {
name string
conf OutboundConfig
}{
{
name: "Empty From",
conf: OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "",
To: "34020000001110000001",
},
},
{
name: "Empty To",
conf: OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "",
},
},
{
name: "From length 19",
conf: OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "3402000000132000001",
To: "34020000001110000001",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewRequest(sip.REGISTER, nil, tt.conf)
if err == nil {
t.Error("Expected error but got nil")
}
})
}
}

View File

@ -1,41 +1,41 @@
package stack package stack
import ( import (
"fmt" "fmt"
"time" "time"
"github.com/emiago/sipgo/sip" "github.com/emiago/sipgo/sip"
) )
const TIME_LAYOUT = "2024-01-01T00:00:00" const TIME_LAYOUT = "2024-01-01T00:00:00"
const EXPIRES_TIME = 3600 const EXPIRES_TIME = 3600
func newResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response { func newResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response {
resp := sip.NewResponseFromRequest(req, code, reason, nil) resp := sip.NewResponseFromRequest(req, code, reason, nil)
newTo := &sip.ToHeader{Address: resp.To().Address, Params: sip.NewParams()} newTo := &sip.ToHeader{Address: resp.To().Address, Params: sip.NewParams()}
newTo.Params.Add("tag", sip.GenerateTagN(10)) newTo.Params.Add("tag", sip.GenerateTagN(10))
resp.ReplaceHeader(newTo) resp.ReplaceHeader(newTo)
resp.RemoveHeader("Allow") resp.RemoveHeader("Allow")
return resp return resp
} }
func NewRegisterResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response { func NewRegisterResponse(req *sip.Request, code sip.StatusCode, reason string) *sip.Response {
resp := newResponse(req, code, reason) resp := newResponse(req, code, reason)
expires := sip.ExpiresHeader(EXPIRES_TIME) expires := sip.ExpiresHeader(EXPIRES_TIME)
resp.AppendHeader(&expires) resp.AppendHeader(&expires)
resp.AppendHeader(sip.NewHeader("Date", time.Now().Format(TIME_LAYOUT))) resp.AppendHeader(sip.NewHeader("Date", time.Now().Format(TIME_LAYOUT)))
return resp return resp
} }
func NewUnauthorizedResponse(req *sip.Request, code sip.StatusCode, reason, nonce, realm string) *sip.Response { func NewUnauthorizedResponse(req *sip.Request, code sip.StatusCode, reason, nonce, realm string) *sip.Response {
resp := newResponse(req, code, reason) resp := newResponse(req, code, reason)
resp.AppendHeader(sip.NewHeader("WWW-Authenticate", fmt.Sprintf(`Digest realm="%s",nonce="%s",algorithm=MD5`, realm, nonce))) resp.AppendHeader(sip.NewHeader("WWW-Authenticate", fmt.Sprintf(`Digest realm="%s",nonce="%s",algorithm=MD5`, realm, nonce)))
return resp return resp
} }

View File

@ -0,0 +1,296 @@
package stack
import (
"testing"
"github.com/emiago/sipgo/sip"
)
func TestNewRegisterResponse(t *testing.T) {
// Create a test request first - we need a properly initialized request
// Skip this test as it requires a full SIP stack to create valid responses
t.Skip("Skipping response test - requires full SIP stack initialization")
conf := OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "34020000001110000001",
}
req, err := NewRegisterRequest(conf)
if err != nil {
t.Fatalf("Failed to create test request: %v", err)
}
tests := []struct {
name string
code sip.StatusCode
reason string
}{
{
name: "200 OK response",
code: sip.StatusOK,
reason: "OK",
},
{
name: "401 Unauthorized response",
code: sip.StatusUnauthorized,
reason: "Unauthorized",
},
{
name: "403 Forbidden response",
code: sip.StatusForbidden,
reason: "Forbidden",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp := NewRegisterResponse(req, tt.code, tt.reason)
if resp == nil {
t.Fatal("Expected response to be non-nil")
}
if resp.StatusCode != tt.code {
t.Errorf("Expected status code %d, got %d", tt.code, resp.StatusCode)
}
if resp.Reason != tt.reason {
t.Errorf("Expected reason '%s', got '%s'", tt.reason, resp.Reason)
}
// Check for Expires header
expires := resp.GetHeader("Expires")
if expires == nil {
t.Error("Expected Expires header to be set")
} else if expires.Value() != "3600" {
t.Errorf("Expected Expires value 3600, got %v", expires.Value())
}
// Check for Date header
date := resp.GetHeader("Date")
if date == nil {
t.Error("Expected Date header to be set")
}
// Check that To header has tag
to := resp.To()
if to == nil {
t.Error("Expected To header to be set")
} else {
tag, ok := to.Params.Get("tag")
if !ok || tag == "" {
t.Error("Expected To header to have tag parameter")
}
}
// Check that Allow header is removed
allow := resp.GetHeader("Allow")
if allow != nil {
t.Error("Expected Allow header to be removed")
}
})
}
}
func TestNewUnauthorizedResponse(t *testing.T) {
// Skip this test as it requires a full SIP stack to create valid responses
t.Skip("Skipping response test - requires full SIP stack initialization")
conf := OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "34020000001110000001",
}
req, err := NewRegisterRequest(conf)
if err != nil {
t.Fatalf("Failed to create test request: %v", err)
}
tests := []struct {
name string
code sip.StatusCode
reason string
nonce string
realm string
}{
{
name: "401 Unauthorized with nonce and realm",
code: sip.StatusUnauthorized,
reason: "Unauthorized",
nonce: "dcd98b7102dd2f0e8b11d0f600bfb0c093",
realm: "3402000000",
},
{
name: "407 Proxy Authentication Required",
code: sip.StatusProxyAuthRequired,
reason: "Proxy Authentication Required",
nonce: "abc123def456",
realm: "proxy.example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp := NewUnauthorizedResponse(req, tt.code, tt.reason, tt.nonce, tt.realm)
if resp == nil {
t.Fatal("Expected response to be non-nil")
}
if resp.StatusCode != tt.code {
t.Errorf("Expected status code %d, got %d", tt.code, resp.StatusCode)
}
if resp.Reason != tt.reason {
t.Errorf("Expected reason '%s', got '%s'", tt.reason, resp.Reason)
}
// Check for WWW-Authenticate header
wwwAuth := resp.GetHeader("WWW-Authenticate")
if wwwAuth == nil {
t.Error("Expected WWW-Authenticate header to be set")
} else {
authValue := wwwAuth.Value()
// Check that it contains the nonce
if len(authValue) == 0 {
t.Error("Expected WWW-Authenticate header to have a value")
}
// The value should contain Digest, realm, nonce, and algorithm
expectedSubstrings := []string{"Digest", "realm=", "nonce=", "algorithm=MD5"}
for _, substr := range expectedSubstrings {
if len(authValue) > 0 && !contains(authValue, substr) {
t.Errorf("Expected WWW-Authenticate to contain '%s', got '%s'", substr, authValue)
}
}
}
// Check that To header has tag
to := resp.To()
if to == nil {
t.Error("Expected To header to be set")
} else {
tag, ok := to.Params.Get("tag")
if !ok || tag == "" {
t.Error("Expected To header to have tag parameter")
}
}
})
}
}
func TestNewResponse(t *testing.T) {
// Skip this test as it requires a full SIP stack to create valid responses
t.Skip("Skipping response test - requires full SIP stack initialization")
conf := OutboundConfig{
Transport: "udp",
Via: "192.168.1.100:5060",
From: "34020000001320000001",
To: "34020000001110000001",
}
req, err := NewRequest(sip.INVITE, nil, conf)
if err != nil {
t.Fatalf("Failed to create test request: %v", err)
}
tests := []struct {
name string
code sip.StatusCode
reason string
}{
{
name: "100 Trying",
code: sip.StatusTrying,
reason: "Trying",
},
{
name: "180 Ringing",
code: sip.StatusRinging,
reason: "Ringing",
},
{
name: "200 OK",
code: sip.StatusOK,
reason: "OK",
},
{
name: "404 Not Found",
code: sip.StatusNotFound,
reason: "Not Found",
},
{
name: "500 Server Internal Error",
code: sip.StatusInternalServerError,
reason: "Server Internal Error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp := newResponse(req, tt.code, tt.reason)
if resp == nil {
t.Fatal("Expected response to be non-nil")
}
if resp.StatusCode != tt.code {
t.Errorf("Expected status code %d, got %d", tt.code, resp.StatusCode)
}
if resp.Reason != tt.reason {
t.Errorf("Expected reason '%s', got '%s'", tt.reason, resp.Reason)
}
// Check that To header has tag
to := resp.To()
if to == nil {
t.Error("Expected To header to be set")
} else {
tag, ok := to.Params.Get("tag")
if !ok || tag == "" {
t.Error("Expected To header to have tag parameter")
}
// Check tag length is 10
if len(tag) != 10 {
t.Errorf("Expected tag length to be 10, got %d", len(tag))
}
}
// Check that Allow header is removed
allow := resp.GetHeader("Allow")
if allow != nil {
t.Error("Expected Allow header to be removed")
}
})
}
}
func TestResponseConstants(t *testing.T) {
if TIME_LAYOUT != "2024-01-01T00:00:00" {
t.Errorf("Expected TIME_LAYOUT to be '2024-01-01T00:00:00', got '%s'", TIME_LAYOUT)
}
if EXPIRES_TIME != 3600 {
t.Errorf("Expected EXPIRES_TIME to be 3600, got %d", EXPIRES_TIME)
}
}
// Helper function to check if a string contains a substring
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && findSubstring(s, substr))
}
func findSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

View File

@ -1,122 +1,122 @@
package service package service
import ( import (
"context" "context"
"fmt" "fmt"
"log/slog" "log/slog"
"github.com/emiago/sipgo" "github.com/emiago/sipgo"
"github.com/emiago/sipgo/sip" "github.com/emiago/sipgo/sip"
"github.com/ossrs/go-oryx-lib/errors" "github.com/ossrs/go-oryx-lib/errors"
"github.com/ossrs/srs-sip/pkg/config" "github.com/ossrs/srs-sip/pkg/config"
"github.com/ossrs/srs-sip/pkg/service/stack" "github.com/ossrs/srs-sip/pkg/service/stack"
) )
const ( const (
UserAgent = "SRS-SIP/1.0" UserAgent = "SRS-SIP/1.0"
) )
type UAC struct { type UAC struct {
*Cascade *Cascade
SN uint32 SN uint32
LocalIP string LocalIP string
} }
func NewUac() *UAC { func NewUac() *UAC {
ip, err := config.GetLocalIP() ip, err := config.GetLocalIP()
if err != nil { if err != nil {
slog.Error("get local ip failed", "error", err) slog.Error("get local ip failed", "error", err)
return nil return nil
} }
c := &UAC{ c := &UAC{
Cascade: &Cascade{}, Cascade: &Cascade{},
LocalIP: ip, LocalIP: ip,
} }
return c return c
} }
func (c *UAC) Start(agent *sipgo.UserAgent, r0 interface{}) error { func (c *UAC) Start(agent *sipgo.UserAgent, r0 interface{}) error {
var err error var err error
c.ctx = context.Background() c.ctx = context.Background()
c.conf = r0.(*config.MainConfig) c.conf = r0.(*config.MainConfig)
if agent == nil { if agent == nil {
ua, err := sipgo.NewUA(sipgo.WithUserAgent(UserAgent)) ua, err := sipgo.NewUA(sipgo.WithUserAgent(UserAgent))
if err != nil { if err != nil {
return err return err
} }
agent = ua agent = ua
} }
c.sipCli, err = sipgo.NewClient(agent, sipgo.WithClientHostname(c.LocalIP)) c.sipCli, err = sipgo.NewClient(agent, sipgo.WithClientHostname(c.LocalIP))
if err != nil { if err != nil {
return err return err
} }
c.sipSvr, err = sipgo.NewServer(agent) c.sipSvr, err = sipgo.NewServer(agent)
if err != nil { if err != nil {
return err return err
} }
c.sipSvr.OnInvite(c.onInvite) c.sipSvr.OnInvite(c.onInvite)
c.sipSvr.OnBye(c.onBye) c.sipSvr.OnBye(c.onBye)
c.sipSvr.OnMessage(c.onMessage) c.sipSvr.OnMessage(c.onMessage)
go c.doRegister() go c.doRegister()
return nil return nil
} }
func (c *UAC) Stop() { func (c *UAC) Stop() {
// TODO: 断开所有当前连接 // TODO: 断开所有当前连接
c.sipCli.Close() c.sipCli.Close()
c.sipSvr.Close() c.sipSvr.Close()
} }
func (c *UAC) doRegister() error { func (c *UAC) doRegister() error {
r, _ := stack.NewRegisterRequest(stack.OutboundConfig{ r, _ := stack.NewRegisterRequest(stack.OutboundConfig{
From: "34020000001110000001", From: "34020000001110000001",
To: "34020000002000000001", To: "34020000002000000001",
Transport: "UDP", Transport: "UDP",
Via: fmt.Sprintf("%s:%d", c.LocalIP, c.conf.GB28181.Port), Via: fmt.Sprintf("%s:%d", c.LocalIP, c.conf.GB28181.Port),
}) })
tx, err := c.sipCli.TransactionRequest(c.ctx, r) tx, err := c.sipCli.TransactionRequest(c.ctx, r)
if err != nil { if err != nil {
return errors.Wrapf(err, "transaction request error") return errors.Wrapf(err, "transaction request error")
} }
rs, _ := c.getResponse(tx) rs, _ := c.getResponse(tx)
slog.Info("register response", "response", rs.String()) slog.Info("register response", "response", rs.String())
return nil return nil
} }
func (c *UAC) OnRequest(req *sip.Request, tx sip.ServerTransaction) { func (c *UAC) OnRequest(req *sip.Request, tx sip.ServerTransaction) {
switch req.Method { switch req.Method {
case "INVITE": case "INVITE":
c.onInvite(req, tx) c.onInvite(req, tx)
} }
} }
func (c *UAC) onInvite(req *sip.Request, tx sip.ServerTransaction) { func (c *UAC) onInvite(req *sip.Request, tx sip.ServerTransaction) {
slog.Debug("onInvite") slog.Debug("onInvite")
} }
func (c *UAC) onBye(req *sip.Request, tx sip.ServerTransaction) { func (c *UAC) onBye(req *sip.Request, tx sip.ServerTransaction) {
slog.Debug("onBye") slog.Debug("onBye")
} }
func (c *UAC) onMessage(req *sip.Request, tx sip.ServerTransaction) { func (c *UAC) onMessage(req *sip.Request, tx sip.ServerTransaction) {
slog.Debug("onMessage", "request", req.String()) slog.Debug("onMessage", "request", req.String())
} }
func (c *UAC) getResponse(tx sip.ClientTransaction) (*sip.Response, error) { func (c *UAC) getResponse(tx sip.ClientTransaction) (*sip.Response, error) {
select { select {
case <-tx.Done(): case <-tx.Done():
return nil, fmt.Errorf("transaction died") return nil, fmt.Errorf("transaction died")
case res := <-tx.Responses(): case res := <-tx.Responses():
return res, nil return res, nil
} }
} }

View File

@ -1,201 +1,201 @@
package utils package utils
import ( import (
"context" "context"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
) )
var logLevelMap = map[string]slog.Level{ var logLevelMap = map[string]slog.Level{
"debug": slog.LevelDebug, "debug": slog.LevelDebug,
"info": slog.LevelInfo, "info": slog.LevelInfo,
"warn": slog.LevelWarn, "warn": slog.LevelWarn,
"error": slog.LevelError, "error": slog.LevelError,
} }
// 自定义格式处理器,以 [时间] [级别] [消息] 格式输出日志 // 自定义格式处理器,以 [时间] [级别] [消息] 格式输出日志
type CustomFormatHandler struct { type CustomFormatHandler struct {
mu sync.Mutex mu sync.Mutex
w io.Writer w io.Writer
level slog.Level level slog.Level
attrs []slog.Attr attrs []slog.Attr
groups []string groups []string
} }
// NewCustomFormatHandler 创建一个新的自定义格式处理器 // NewCustomFormatHandler 创建一个新的自定义格式处理器
func NewCustomFormatHandler(w io.Writer, opts *slog.HandlerOptions) *CustomFormatHandler { func NewCustomFormatHandler(w io.Writer, opts *slog.HandlerOptions) *CustomFormatHandler {
if opts == nil { if opts == nil {
opts = &slog.HandlerOptions{} opts = &slog.HandlerOptions{}
} }
// 获取日志级别如果opts.Level是nil则默认为Info // 获取日志级别如果opts.Level是nil则默认为Info
var level slog.Level var level slog.Level
if opts.Level != nil { if opts.Level != nil {
level = opts.Level.Level() level = opts.Level.Level()
} else { } else {
level = slog.LevelInfo level = slog.LevelInfo
} }
return &CustomFormatHandler{ return &CustomFormatHandler{
w: w, w: w,
level: level, level: level,
} }
} }
// Enabled 实现 slog.Handler 接口 // Enabled 实现 slog.Handler 接口
func (h *CustomFormatHandler) Enabled(ctx context.Context, level slog.Level) bool { func (h *CustomFormatHandler) Enabled(ctx context.Context, level slog.Level) bool {
return level >= h.level return level >= h.level
} }
// Handle 实现 slog.Handler 接口,以自定义格式输出日志 // Handle 实现 slog.Handler 接口,以自定义格式输出日志
func (h *CustomFormatHandler) Handle(ctx context.Context, record slog.Record) error { func (h *CustomFormatHandler) Handle(ctx context.Context, record slog.Record) error {
h.mu.Lock() h.mu.Lock()
defer h.mu.Unlock() defer h.mu.Unlock()
// 时间格式 // 时间格式
timeStr := record.Time.Format("2006-01-02 15:04:05.000") timeStr := record.Time.Format("2006-01-02 15:04:05.000")
// 日志级别 // 日志级别
var levelStr string var levelStr string
switch { switch {
case record.Level >= slog.LevelError: case record.Level >= slog.LevelError:
levelStr = "ERROR" levelStr = "ERROR"
case record.Level >= slog.LevelWarn: case record.Level >= slog.LevelWarn:
levelStr = "WARN " levelStr = "WARN "
case record.Level >= slog.LevelInfo: case record.Level >= slog.LevelInfo:
levelStr = "INFO " levelStr = "INFO "
default: default:
levelStr = "DEBUG" levelStr = "DEBUG"
} }
// 构建日志行 // 构建日志行
logLine := fmt.Sprintf("[%s] [%s] %s", timeStr, levelStr, record.Message) logLine := fmt.Sprintf("[%s] [%s] %s", timeStr, levelStr, record.Message)
// 处理其他属性 // 处理其他属性
var attrs []string var attrs []string
record.Attrs(func(attr slog.Attr) bool { record.Attrs(func(attr slog.Attr) bool {
attrs = append(attrs, fmt.Sprintf("%s=%v", attr.Key, attr.Value)) attrs = append(attrs, fmt.Sprintf("%s=%v", attr.Key, attr.Value))
return true return true
}) })
if len(attrs) > 0 { if len(attrs) > 0 {
logLine += " " + strings.Join(attrs, " ") logLine += " " + strings.Join(attrs, " ")
} }
// 写入日志 // 写入日志
_, err := fmt.Fprintln(h.w, logLine) _, err := fmt.Fprintln(h.w, logLine)
return err return err
} }
// WithAttrs 实现 slog.Handler 接口 // WithAttrs 实现 slog.Handler 接口
func (h *CustomFormatHandler) WithAttrs(attrs []slog.Attr) slog.Handler { func (h *CustomFormatHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
h2 := *h h2 := *h
h2.attrs = append(h.attrs[:], attrs...) h2.attrs = append(h.attrs[:], attrs...)
return &h2 return &h2
} }
// WithGroup 实现 slog.Handler 接口 // WithGroup 实现 slog.Handler 接口
func (h *CustomFormatHandler) WithGroup(name string) slog.Handler { func (h *CustomFormatHandler) WithGroup(name string) slog.Handler {
h2 := *h h2 := *h
h2.groups = append(h.groups[:], name) h2.groups = append(h.groups[:], name)
return &h2 return &h2
} }
// MultiHandler 实现了 slog.Handler 接口,将日志同时发送到多个处理器 // MultiHandler 实现了 slog.Handler 接口,将日志同时发送到多个处理器
type MultiHandler struct { type MultiHandler struct {
handlers []slog.Handler handlers []slog.Handler
} }
// Enabled 实现 slog.Handler 接口 // Enabled 实现 slog.Handler 接口
func (h *MultiHandler) Enabled(ctx context.Context, level slog.Level) bool { func (h *MultiHandler) Enabled(ctx context.Context, level slog.Level) bool {
// 如果任何一个处理器启用了该级别,则返回 true // 如果任何一个处理器启用了该级别,则返回 true
for _, handler := range h.handlers { for _, handler := range h.handlers {
if handler.Enabled(ctx, level) { if handler.Enabled(ctx, level) {
return true return true
} }
} }
return false return false
} }
// Handle 实现 slog.Handler 接口 // Handle 实现 slog.Handler 接口
func (h *MultiHandler) Handle(ctx context.Context, record slog.Record) error { func (h *MultiHandler) Handle(ctx context.Context, record slog.Record) error {
// 将记录发送到所有处理器 // 将记录发送到所有处理器
for _, handler := range h.handlers { for _, handler := range h.handlers {
if handler.Enabled(ctx, record.Level) { if handler.Enabled(ctx, record.Level) {
if err := handler.Handle(ctx, record); err != nil { if err := handler.Handle(ctx, record); err != nil {
return err return err
} }
} }
} }
return nil return nil
} }
// WithAttrs 实现 slog.Handler 接口 // WithAttrs 实现 slog.Handler 接口
func (h *MultiHandler) WithAttrs(attrs []slog.Attr) slog.Handler { func (h *MultiHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
newHandlers := make([]slog.Handler, len(h.handlers)) newHandlers := make([]slog.Handler, len(h.handlers))
for i, handler := range h.handlers { for i, handler := range h.handlers {
newHandlers[i] = handler.WithAttrs(attrs) newHandlers[i] = handler.WithAttrs(attrs)
} }
return &MultiHandler{handlers: newHandlers} return &MultiHandler{handlers: newHandlers}
} }
// WithGroup 实现 slog.Handler 接口 // WithGroup 实现 slog.Handler 接口
func (h *MultiHandler) WithGroup(name string) slog.Handler { func (h *MultiHandler) WithGroup(name string) slog.Handler {
newHandlers := make([]slog.Handler, len(h.handlers)) newHandlers := make([]slog.Handler, len(h.handlers))
for i, handler := range h.handlers { for i, handler := range h.handlers {
newHandlers[i] = handler.WithGroup(name) newHandlers[i] = handler.WithGroup(name)
} }
return &MultiHandler{handlers: newHandlers} return &MultiHandler{handlers: newHandlers}
} }
// SetupLogger 设置日志输出 // SetupLogger 设置日志输出
func SetupLogger(logLevel string, logFile string) error { func SetupLogger(logLevel string, logFile string) error {
// 创建标准错误输出的处理器,使用自定义格式 // 创建标准错误输出的处理器,使用自定义格式
stdHandler := NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{ stdHandler := NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{
Level: logLevelMap[logLevel], Level: logLevelMap[logLevel],
}) })
// 如果没有指定日志文件,则仅使用标准错误处理器 // 如果没有指定日志文件,则仅使用标准错误处理器
if logFile == "" { if logFile == "" {
slog.SetDefault(slog.New(stdHandler)) slog.SetDefault(slog.New(stdHandler))
return nil return nil
} }
// 确保日志文件所在目录存在 // 确保日志文件所在目录存在
logDir := filepath.Dir(logFile) logDir := filepath.Dir(logFile)
if err := os.MkdirAll(logDir, 0755); err != nil { if err := os.MkdirAll(logDir, 0755); err != nil {
return err return err
} }
// 打开日志文件,如果不存在则创建,追加写入模式 // 打开日志文件,如果不存在则创建,追加写入模式
file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil { if err != nil {
return err return err
} }
// 创建文件输出的处理器,使用自定义格式 // 创建文件输出的处理器,使用自定义格式
fileHandler := NewCustomFormatHandler(file, &slog.HandlerOptions{ fileHandler := NewCustomFormatHandler(file, &slog.HandlerOptions{
Level: logLevelMap[logLevel], Level: logLevelMap[logLevel],
}) })
// 创建多输出处理器 // 创建多输出处理器
multiHandler := &MultiHandler{ multiHandler := &MultiHandler{
handlers: []slog.Handler{stdHandler, fileHandler}, handlers: []slog.Handler{stdHandler, fileHandler},
} }
// 设置全局日志处理器 // 设置全局日志处理器
slog.SetDefault(slog.New(multiHandler)) slog.SetDefault(slog.New(multiHandler))
return nil return nil
} }
// InitDefaultLogger 初始化默认日志处理器 // InitDefaultLogger 初始化默认日志处理器
func InitDefaultLogger(level slog.Level) { func InitDefaultLogger(level slog.Level) {
slog.SetDefault(slog.New(NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{ slog.SetDefault(slog.New(NewCustomFormatHandler(os.Stderr, &slog.HandlerOptions{
Level: level, Level: level,
}))) })))
} }

197
pkg/utils/utils_test.go Normal file
View File

@ -0,0 +1,197 @@
package utils
import (
"testing"
)
func TestGenRandomNumber(t *testing.T) {
tests := []struct {
name string
length int
}{
{"Generate 1 digit", 1},
{"Generate 5 digits", 5},
{"Generate 9 digits", 9},
{"Generate 10 digits", 10},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GenRandomNumber(tt.length)
// 验证长度
if len(result) != tt.length {
t.Errorf("Expected length %d, got %d", tt.length, len(result))
}
// 验证所有字符都是数字
for i, c := range result {
if c < '0' || c > '9' {
t.Errorf("Character at position %d is not a digit: %c", i, c)
}
}
})
}
}
func TestGenRandomNumberUniqueness(t *testing.T) {
// 生成多个随机数,验证它们不完全相同(虽然理论上可能相同,但概率极低)
results := make(map[string]bool)
iterations := 100
length := 10
for i := 0; i < iterations; i++ {
result := GenRandomNumber(length)
results[result] = true
}
// 至少应该有一些不同的值不太可能100次都生成相同的10位数
if len(results) < 50 {
t.Errorf("Expected at least 50 unique values out of %d iterations, got %d", iterations, len(results))
}
}
func TestCreateSSRC(t *testing.T) {
tests := []struct {
name string
isLive bool
expected byte
}{
{"Live stream SSRC", true, '0'},
{"Non-live stream SSRC", false, '1'},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ssrc := CreateSSRC(tt.isLive)
// 验证长度为10
if len(ssrc) != 10 {
t.Errorf("Expected SSRC length 10, got %d", len(ssrc))
}
// 验证第一个字符
if ssrc[0] != tt.expected {
t.Errorf("Expected first character '%c', got '%c'", tt.expected, ssrc[0])
}
// 验证所有字符都是数字
for i, c := range ssrc {
if c < '0' || c > '9' {
t.Errorf("Character at position %d is not a digit: %c", i, c)
}
}
})
}
}
func TestCreateSSRCUniqueness(t *testing.T) {
// 测试生成的 SSRC 具有唯一性
results := make(map[string]bool)
iterations := 100
for i := 0; i < iterations; i++ {
ssrc := CreateSSRC(true)
results[ssrc] = true
}
// 应该有很多不同的值
if len(results) < 50 {
t.Errorf("Expected at least 50 unique SSRCs out of %d iterations, got %d", iterations, len(results))
}
}
func TestIsVideoChannel(t *testing.T) {
tests := []struct {
name string
channelID string
expected bool
}{
{
name: "Video channel type 131",
channelID: "34020000001310000001",
expected: true,
},
{
name: "Video channel type 132",
channelID: "34020000001320000001",
expected: true,
},
{
name: "Audio channel type 137",
channelID: "34020000001370000001",
expected: false,
},
{
name: "Alarm channel type 134",
channelID: "34020000001340000001",
expected: false,
},
{
name: "Other device type",
channelID: "34020000001110000001",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsVideoChannel(tt.channelID)
if result != tt.expected {
t.Errorf("IsVideoChannel(%s) = %v, expected %v", tt.channelID, result, tt.expected)
}
})
}
}
func TestGetSessionName(t *testing.T) {
tests := []struct {
name string
playType int
expected string
}{
{"Live play", 0, "Play"},
{"Playback", 1, "Playback"},
{"Download", 2, "Download"},
{"Talk", 3, "Talk"},
{"Unknown type", 99, "Play"},
{"Negative type", -1, "Play"},
{"Type 4", 4, "Play"},
{"Type 5", 5, "Play"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetSessionName(tt.playType)
if result != tt.expected {
t.Errorf("GetSessionName(%d) = %s, expected %s", tt.playType, result, tt.expected)
}
})
}
}
func TestGenRandomNumberZeroLength(t *testing.T) {
result := GenRandomNumber(0)
if len(result) != 0 {
t.Errorf("Expected empty string for length 0, got %s", result)
}
}
func TestCreateSSRCBothTypes(t *testing.T) {
// Test both live and non-live in one test
liveSSRC := CreateSSRC(true)
nonLiveSSRC := CreateSSRC(false)
if liveSSRC[0] != '0' {
t.Errorf("Live SSRC should start with '0', got '%c'", liveSSRC[0])
}
if nonLiveSSRC[0] != '1' {
t.Errorf("Non-live SSRC should start with '1', got '%c'", nonLiveSSRC[0])
}
// They should be different (with very high probability)
if liveSSRC == nonLiveSSRC {
t.Error("Live and non-live SSRCs should be different")
}
}

View File

@ -1,30 +1,30 @@
package main package main
import ( import (
"context" "context"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"github.com/ossrs/go-oryx-lib/logger" "github.com/ossrs/go-oryx-lib/logger"
"github.com/ossrs/srs-bench/gb28181" "github.com/ossrs/srs-bench/gb28181"
) )
func main() { func main() {
ctx := context.Background() ctx := context.Background()
var conf interface{} var conf interface{}
conf = gb28181.Parse(ctx) conf = gb28181.Parse(ctx)
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
go func() { go func() {
sigs := make(chan os.Signal, 1) sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT) signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT)
for sig := range sigs { for sig := range sigs {
logger.Wf(ctx, "Quit for signal %v", sig) logger.Wf(ctx, "Quit for signal %v", sig)
cancel() cancel()
} }
}() }()
gb28181.Run(ctx, conf) gb28181.Run(ctx, conf)
} }