Files
2026-05-11 08:51:53 +08:00

324 lines
7.1 KiB
Go

package main
import (
	"bufio"
	"database/sql"
	"encoding/hex"
	"flag"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"strings"
	"time"
	"unicode/utf8"

	"hr_receiver/config"
)
// defaultTables lists the tables exported when neither the -tables flag nor
// positional table names are supplied on the command line.
var defaultTables = []string{
	"users",
	"user_region_bindings",
	"kindergartens",
}
// main parses the command-line flags, connects to the database through the
// project's config package, and writes one INSERT statement per row of each
// requested table to a .sql file. It exits with status 1 on any failure.
func main() {
	var (
		tablesArg string
		outputArg string
		schemaArg string
	)
	flag.StringVar(&tablesArg, "tables", "", "Comma-separated table names to export. Default: users,user_region_bindings,kindergartens")
	flag.StringVar(&outputArg, "output", "", "Output .sql file path. Default: export_<timestamp>.sql")
	flag.StringVar(&schemaArg, "schema", "public", "Database schema name")
	flag.Parse()

	tables := resolveTables(tablesArg, flag.Args())
	if len(tables) == 0 {
		fmt.Fprintln(os.Stderr, "no tables specified")
		os.Exit(1)
	}

	config.InitConfig()
	config.ConnectDB()

	outputPath := outputArg
	if outputPath == "" {
		outputPath = fmt.Sprintf("export_%s.sql", time.Now().Format("20060102_150405"))
	}

	// File handling lives in writeExport so that flushing and closing happen
	// before os.Exit, which would otherwise skip any deferred calls.
	if err := writeExport(outputPath, schemaArg, tables); err != nil {
		fmt.Fprintf(os.Stderr, "%v\n", err)
		os.Exit(1)
	}
	fmt.Printf("exported %d table(s) to %s\n", len(tables), outputPath)
}

// writeExport creates outputPath (and its parent directory when one is
// named), streams the export of tables into it through a buffered writer,
// then flushes and closes the file, reporting any error from those steps.
// On failure the partially written file is removed so that a truncated dump
// cannot be mistaken for a complete one.
func writeExport(outputPath, schema string, tables []string) error {
	if dir := filepath.Dir(outputPath); dir != "." {
		if err := os.MkdirAll(dir, 0o755); err != nil {
			return fmt.Errorf("create output directory failed: %w", err)
		}
	}

	file, err := os.Create(outputPath)
	if err != nil {
		return fmt.Errorf("create output file failed: %w", err)
	}

	// Every row is a separate small write; buffering avoids a syscall per row.
	w := bufio.NewWriter(file)
	if err = exportTables(w, schema, tables); err != nil {
		err = fmt.Errorf("export failed: %w", err)
	} else if err = w.Flush(); err != nil {
		err = fmt.Errorf("write output file failed: %w", err)
	}
	// Close errors on a written file can mean lost data, so they are reported.
	if closeErr := file.Close(); closeErr != nil && err == nil {
		err = fmt.Errorf("close output file failed: %w", closeErr)
	}
	if err != nil {
		// Best effort: the export error is what the caller needs to see.
		_ = os.Remove(outputPath)
		return err
	}
	return nil
}
// resolveTables decides which tables to export. Positional arguments win over
// the comma-separated -tables value; when neither is given, a fresh copy of
// defaultTables is returned so the caller may modify it safely.
func resolveTables(tablesArg string, positional []string) []string {
	switch {
	case len(positional) > 0:
		return dedupeNonEmpty(positional)
	case strings.TrimSpace(tablesArg) != "":
		return dedupeNonEmpty(strings.Split(tablesArg, ","))
	default:
		return append([]string(nil), defaultTables...)
	}
}
// dedupeNonEmpty trims surrounding whitespace from each item and returns the
// non-empty results in first-seen order, dropping any later duplicates. The
// result is never nil, even for empty input.
func dedupeNonEmpty(items []string) []string {
	out := make([]string, 0, len(items))
	seen := make(map[string]bool, len(items))
	for _, raw := range items {
		name := strings.TrimSpace(raw)
		if name != "" && !seen[name] {
			seen[name] = true
			out = append(out, name)
		}
	}
	return out
}
// exportTables writes a header followed by one INSERT statement per row for
// every table in tables (in the given order), reading through the project's
// shared database handle.
//
// Column lists and primary keys for all tables are resolved before anything
// is written, so a misspelled or missing table fails the export up front
// instead of after earlier tables have already been emitted.
//
// The header pins client_encoding and standard_conforming_strings in the
// importing session: the literals produced for the rows are raw UTF-8 and
// only double single quotes, so they are only parsed correctly under those
// settings.
func exportTables(w io.Writer, schema string, tables []string) error {
	db, err := config.DB.DB()
	if err != nil {
		return err
	}

	type tablePlan struct {
		name    string
		columns []string
		orderBy []string
	}
	plans := make([]tablePlan, 0, len(tables))
	for _, table := range tables {
		columns, err := getColumns(db, schema, table)
		if err != nil {
			return fmt.Errorf("read columns for %s: %w", table, err)
		}
		if len(columns) == 0 {
			return fmt.Errorf("table %s not found or has no columns", table)
		}
		orderBy, err := getPrimaryKeyColumns(db, schema, table)
		if err != nil {
			return fmt.Errorf("read primary key for %s: %w", table, err)
		}
		plans = append(plans, tablePlan{name: table, columns: columns, orderBy: orderBy})
	}

	header := []string{
		"-- Generated by cmd/export_sql",
		fmt.Sprintf("-- Time: %s", time.Now().Format(time.RFC3339)),
		fmt.Sprintf("-- Schema: %s", schema),
		fmt.Sprintf("-- Tables: %s", strings.Join(tables, ", ")),
		"",
		"SET client_encoding = 'UTF8';",
		"SET standard_conforming_strings = on;",
		"",
	}
	if _, err := io.WriteString(w, strings.Join(header, "\n")); err != nil {
		return err
	}

	for _, plan := range plans {
		if _, err := fmt.Fprintf(w, "-- Table: %s\n", plan.name); err != nil {
			return err
		}
		if err := exportTableData(w, db, schema, plan.name, plan.columns, plan.orderBy); err != nil {
			return err
		}
		if _, err := io.WriteString(w, "\n"); err != nil {
			return err
		}
	}
	return nil
}
// getColumns returns the insertable column names of schema.table in
// definition (ordinal) order. An empty result means the table was not found;
// information_schema only lists columns the current role has access to.
//
// Generated columns (GENERATED ALWAYS AS ... STORED) are skipped: PostgreSQL
// rejects explicit values for them on INSERT and recomputes them itself, so
// including them would make the exported INSERT statements fail on import.
func getColumns(db *sql.DB, schema, table string) ([]string, error) {
	const query = `
SELECT column_name
FROM information_schema.columns
WHERE table_schema = $1 AND table_name = $2
AND is_generated = 'NEVER'
ORDER BY ordinal_position
`
	rows, err := db.Query(query, schema, table)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var columns []string
	for rows.Next() {
		var column string
		if err := rows.Scan(&column); err != nil {
			return nil, err
		}
		columns = append(columns, column)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return columns, nil
}
// getPrimaryKeyColumns looks up the primary key of schema.table in the
// PostgreSQL system catalogs and returns its column names in key order.
// A table without a primary key yields a nil slice and no error.
func getPrimaryKeyColumns(db *sql.DB, schema, table string) ([]string, error) {
	const pkQuery = `
SELECT a.attname
FROM pg_index i
JOIN pg_class c ON c.oid = i.indrelid
JOIN pg_namespace n ON n.oid = c.relnamespace
JOIN pg_attribute a ON a.attrelid = c.oid AND a.attnum = ANY(i.indkey)
WHERE i.indisprimary
AND n.nspname = $1
AND c.relname = $2
ORDER BY array_position(i.indkey, a.attnum)
`
	rows, queryErr := db.Query(pkQuery, schema, table)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	var keyColumns []string
	for rows.Next() {
		var name string
		if scanErr := rows.Scan(&name); scanErr != nil {
			return nil, scanErr
		}
		keyColumns = append(keyColumns, name)
	}
	return keyColumns, rows.Err()
}
func exportTableData(w io.Writer, db *sql.DB, schema, table string, columns, orderBy []string) error {
query := fmt.Sprintf(
`SELECT %s FROM %s.%s`,
joinIdentifiers(columns),
quoteIdentifier(schema),
quoteIdentifier(table),
)
if len(orderBy) > 0 {
query += " ORDER BY " + joinIdentifiers(orderBy)
}
rows, err := db.Query(query)
if err != nil {
return fmt.Errorf("query rows for %s: %w", table, err)
}
defer rows.Close()
values := make([]any, len(columns))
scanTargets := make([]any, len(columns))
for i := range values {
scanTargets[i] = &values[i]
}
rowCount := 0
for rows.Next() {
clear(values)
if err := rows.Scan(scanTargets...); err != nil {
return fmt.Errorf("scan row for %s: %w", table, err)
}
literals := make([]string, len(columns))
for i, value := range values {
literals[i] = toSQLLiteral(value)
}
stmt := fmt.Sprintf(
"INSERT INTO %s.%s (%s) VALUES (%s);\n",
quoteIdentifier(schema),
quoteIdentifier(table),
joinIdentifiers(columns),
strings.Join(literals, ", "),
)
if _, err := io.WriteString(w, stmt); err != nil {
return err
}
rowCount++
}
if err := rows.Err(); err != nil {
return err
}
if rowCount == 0 {
_, err := io.WriteString(w, "-- no rows\n")
return err
}
return nil
}
// joinIdentifiers quotes each name with quoteIdentifier and joins the results
// with ", ", producing a column list for SELECT, INSERT or ORDER BY clauses.
// An empty input yields the empty string.
func joinIdentifiers(columns []string) string {
	var b strings.Builder
	for idx, name := range columns {
		if idx > 0 {
			b.WriteString(", ")
		}
		b.WriteString(quoteIdentifier(name))
	}
	return b.String()
}
// quoteIdentifier wraps value in double quotes, doubling any embedded double
// quote, so it can be used verbatim as a quoted SQL identifier.
func quoteIdentifier(value string) string {
	const dq = `"`
	escaped := strings.Replace(value, dq, dq+dq, -1)
	return dq + escaped + dq
}
// toSQLLiteral renders a value scanned through database/sql into *any as a
// PostgreSQL literal suitable for an INSERT statement:
//
//   - nil becomes NULL;
//   - strings are single-quoted via quoteString;
//   - []byte is quoted as text when isPrintableUTF8 accepts it, otherwise as
//     a hex literal ('\x...'), which relies on standard_conforming_strings
//     being on in the importing session;
//   - booleans and integers are written bare;
//   - floats are written bare, except NaN and ±Inf, which PostgreSQL only
//     accepts as the quoted strings 'NaN', 'Infinity' and '-Infinity';
//   - time.Time is quoted in RFC 3339 form with nanoseconds;
//   - anything else is quoted using its fmt.Sprint representation.
func toSQLLiteral(value any) string {
	switch v := value.(type) {
	case nil:
		return "NULL"
	case string:
		return quoteString(v)
	case []byte:
		if len(v) == 0 {
			return quoteString("")
		}
		if isPrintableUTF8(v) {
			return quoteString(string(v))
		}
		return quoteString(`\x` + hex.EncodeToString(v))
	case bool:
		if v {
			return "TRUE"
		}
		return "FALSE"
	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
		return fmt.Sprintf("%d", v)
	case float32:
		return floatLiteral(fmt.Sprintf("%g", v))
	case float64:
		return floatLiteral(fmt.Sprintf("%g", v))
	case time.Time:
		return quoteString(v.Format(time.RFC3339Nano))
	default:
		return quoteString(fmt.Sprint(v))
	}
}

// floatLiteral converts the %g formatting of a float into a PostgreSQL
// literal. fmt spells the special values "NaN", "+Inf" and "-Inf", none of
// which is a valid SQL token, so they are mapped to PostgreSQL's quoted
// spellings; every other value is already a valid numeric literal.
func floatLiteral(formatted string) string {
	switch formatted {
	case "NaN":
		return "'NaN'"
	case "+Inf":
		return "'Infinity'"
	case "-Inf":
		return "'-Infinity'"
	default:
		return formatted
	}
}
// quoteString returns value as a single-quoted SQL string literal, doubling
// every embedded single quote. Bytes are copied through unchanged (including
// any that are not valid UTF-8); no backslash escaping is performed.
func quoteString(value string) string {
	var b strings.Builder
	b.Grow(len(value) + 2)
	b.WriteByte('\'')
	for i := 0; i < len(value); i++ {
		c := value[i]
		if c == '\'' {
			b.WriteByte('\'')
		}
		b.WriteByte(c)
	}
	b.WriteByte('\'')
	return b.String()
}
// isPrintableUTF8 reports whether data can be emitted as an ordinary quoted
// SQL text literal: it must be valid UTF-8 and must not contain a NUL byte
// (PostgreSQL text values cannot hold NUL, and a UTF-8 server rejects invalid
// byte sequences in string literals). Empty input reports true.
func isPrintableUTF8(data []byte) bool {
	if !utf8.Valid(data) {
		return false
	}
	for _, b := range data {
		if b == 0 {
			return false
		}
	}
	return true
}