324 lines
7.1 KiB
Go
324 lines
7.1 KiB
Go
package main
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"flag"
|
|
"fmt"
|
|
"hr_receiver/config"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// defaultTables lists the tables exported when neither the -tables flag nor
// any positional argument names a table.
var defaultTables = []string{"users", "user_region_bindings", "kindergartens"}
|
|
|
|
func main() {
|
|
var (
|
|
tablesArg string
|
|
outputArg string
|
|
schemaArg string
|
|
)
|
|
|
|
flag.StringVar(&tablesArg, "tables", "", "Comma-separated table names to export. Default: users,user_region_bindings,kindergartens")
|
|
flag.StringVar(&outputArg, "output", "", "Output .sql file path. Default: export_<timestamp>.sql")
|
|
flag.StringVar(&schemaArg, "schema", "public", "Database schema name")
|
|
flag.Parse()
|
|
|
|
tables := resolveTables(tablesArg, flag.Args())
|
|
if len(tables) == 0 {
|
|
fmt.Fprintln(os.Stderr, "no tables specified")
|
|
os.Exit(1)
|
|
}
|
|
|
|
config.InitConfig()
|
|
config.ConnectDB()
|
|
|
|
outputPath := outputArg
|
|
if outputPath == "" {
|
|
outputPath = fmt.Sprintf("export_%s.sql", time.Now().Format("20060102_150405"))
|
|
}
|
|
|
|
if err := os.MkdirAll(filepath.Dir(outputPath), 0o755); err != nil && filepath.Dir(outputPath) != "." {
|
|
fmt.Fprintf(os.Stderr, "create output directory failed: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
file, err := os.Create(outputPath)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "create output file failed: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
defer file.Close()
|
|
|
|
if err := exportTables(file, schemaArg, tables); err != nil {
|
|
fmt.Fprintf(os.Stderr, "export failed: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
fmt.Printf("exported %d table(s) to %s\n", len(tables), outputPath)
|
|
}
|
|
|
|
func resolveTables(tablesArg string, positional []string) []string {
|
|
if len(positional) > 0 {
|
|
return dedupeNonEmpty(positional)
|
|
}
|
|
if strings.TrimSpace(tablesArg) == "" {
|
|
return append([]string(nil), defaultTables...)
|
|
}
|
|
return dedupeNonEmpty(strings.Split(tablesArg, ","))
|
|
}
|
|
|
|
// dedupeNonEmpty trims surrounding whitespace from every item, drops items
// that are empty after trimming, and removes repeats while keeping the order
// of first occurrence. The result is always a non-nil slice.
func dedupeNonEmpty(items []string) []string {
	unique := make([]string, 0, len(items))
	visited := make(map[string]bool, len(items))
	for idx := range items {
		name := strings.TrimSpace(items[idx])
		if name == "" || visited[name] {
			continue
		}
		visited[name] = true
		unique = append(unique, name)
	}
	return unique
}
|
|
|
|
func exportTables(w io.Writer, schema string, tables []string) error {
|
|
db, err := config.DB.DB()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
header := []string{
|
|
"-- Generated by cmd/export_sql",
|
|
fmt.Sprintf("-- Time: %s", time.Now().Format(time.RFC3339)),
|
|
fmt.Sprintf("-- Schema: %s", schema),
|
|
fmt.Sprintf("-- Tables: %s", strings.Join(tables, ", ")),
|
|
"",
|
|
}
|
|
if _, err := io.WriteString(w, strings.Join(header, "\n")); err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, table := range tables {
|
|
columns, err := getColumns(db, schema, table)
|
|
if err != nil {
|
|
return fmt.Errorf("read columns for %s: %w", table, err)
|
|
}
|
|
if len(columns) == 0 {
|
|
return fmt.Errorf("table %s not found or has no columns", table)
|
|
}
|
|
|
|
orderBy, err := getPrimaryKeyColumns(db, schema, table)
|
|
if err != nil {
|
|
return fmt.Errorf("read primary key for %s: %w", table, err)
|
|
}
|
|
|
|
if _, err := fmt.Fprintf(w, "-- Table: %s\n", table); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := exportTableData(w, db, schema, table, columns, orderBy); err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, err := io.WriteString(w, "\n"); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// getColumns returns the column names of schema.table as listed in
// information_schema.columns, ordered by ordinal position. A table that does
// not exist in the schema yields an empty result rather than an error.
func getColumns(db *sql.DB, schema, table string) ([]string, error) {
	const query = `
		SELECT column_name
		FROM information_schema.columns
		WHERE table_schema = $1 AND table_name = $2
		ORDER BY ordinal_position
	`
	rows, err := db.Query(query, schema, table)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var names []string
	for rows.Next() {
		var name string
		if scanErr := rows.Scan(&name); scanErr != nil {
			return nil, scanErr
		}
		names = append(names, name)
	}
	// rows.Err reports a failure that ended the iteration early.
	return names, rows.Err()
}
|
|
|
|
// getPrimaryKeyColumns returns the names of the columns that make up the
// primary key of schema.table, read from the PostgreSQL system catalogs
// (pg_index, pg_class, pg_namespace, pg_attribute) and ordered by their
// position in the index key. The result is empty when no primary-key index
// matches.
func getPrimaryKeyColumns(db *sql.DB, schema, table string) ([]string, error) {
	const query = `
		SELECT a.attname
		FROM pg_index i
		JOIN pg_class c ON c.oid = i.indrelid
		JOIN pg_namespace n ON n.oid = c.relnamespace
		JOIN pg_attribute a ON a.attrelid = c.oid AND a.attnum = ANY(i.indkey)
		WHERE i.indisprimary
			AND n.nspname = $1
			AND c.relname = $2
		ORDER BY array_position(i.indkey, a.attnum)
	`
	rows, err := db.Query(query, schema, table)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var keyColumns []string
	for rows.Next() {
		var attname string
		if scanErr := rows.Scan(&attname); scanErr != nil {
			return nil, scanErr
		}
		keyColumns = append(keyColumns, attname)
	}
	// rows.Err reports a failure that ended the iteration early.
	return keyColumns, rows.Err()
}
|
|
|
|
func exportTableData(w io.Writer, db *sql.DB, schema, table string, columns, orderBy []string) error {
|
|
query := fmt.Sprintf(
|
|
`SELECT %s FROM %s.%s`,
|
|
joinIdentifiers(columns),
|
|
quoteIdentifier(schema),
|
|
quoteIdentifier(table),
|
|
)
|
|
if len(orderBy) > 0 {
|
|
query += " ORDER BY " + joinIdentifiers(orderBy)
|
|
}
|
|
|
|
rows, err := db.Query(query)
|
|
if err != nil {
|
|
return fmt.Errorf("query rows for %s: %w", table, err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
values := make([]any, len(columns))
|
|
scanTargets := make([]any, len(columns))
|
|
for i := range values {
|
|
scanTargets[i] = &values[i]
|
|
}
|
|
|
|
rowCount := 0
|
|
for rows.Next() {
|
|
clear(values)
|
|
if err := rows.Scan(scanTargets...); err != nil {
|
|
return fmt.Errorf("scan row for %s: %w", table, err)
|
|
}
|
|
|
|
literals := make([]string, len(columns))
|
|
for i, value := range values {
|
|
literals[i] = toSQLLiteral(value)
|
|
}
|
|
|
|
stmt := fmt.Sprintf(
|
|
"INSERT INTO %s.%s (%s) VALUES (%s);\n",
|
|
quoteIdentifier(schema),
|
|
quoteIdentifier(table),
|
|
joinIdentifiers(columns),
|
|
strings.Join(literals, ", "),
|
|
)
|
|
if _, err := io.WriteString(w, stmt); err != nil {
|
|
return err
|
|
}
|
|
rowCount++
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
if rowCount == 0 {
|
|
_, err := io.WriteString(w, "-- no rows\n")
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func joinIdentifiers(columns []string) string {
|
|
quoted := make([]string, len(columns))
|
|
for i, column := range columns {
|
|
quoted[i] = quoteIdentifier(column)
|
|
}
|
|
return strings.Join(quoted, ", ")
|
|
}
|
|
|
|
// quoteIdentifier wraps value in double quotes and doubles every embedded
// double quote, yielding a quoted SQL identifier.
func quoteIdentifier(value string) string {
	var quoted strings.Builder
	quoted.Grow(len(value) + 2)
	quoted.WriteByte('"')
	// Byte-wise iteration is safe: '"' is ASCII and never occurs inside a
	// multi-byte UTF-8 sequence.
	for i := 0; i < len(value); i++ {
		if value[i] == '"' {
			quoted.WriteByte('"')
		}
		quoted.WriteByte(value[i])
	}
	quoted.WriteByte('"')
	return quoted.String()
}
|
|
|
|
func toSQLLiteral(value any) string {
|
|
switch v := value.(type) {
|
|
case nil:
|
|
return "NULL"
|
|
case string:
|
|
return quoteString(v)
|
|
case []byte:
|
|
if len(v) == 0 {
|
|
return quoteString("")
|
|
}
|
|
if isPrintableUTF8(v) {
|
|
return quoteString(string(v))
|
|
}
|
|
return quoteString(`\x` + hex.EncodeToString(v))
|
|
case bool:
|
|
if v {
|
|
return "TRUE"
|
|
}
|
|
return "FALSE"
|
|
case int:
|
|
return fmt.Sprintf("%d", v)
|
|
case int8:
|
|
return fmt.Sprintf("%d", v)
|
|
case int16:
|
|
return fmt.Sprintf("%d", v)
|
|
case int32:
|
|
return fmt.Sprintf("%d", v)
|
|
case int64:
|
|
return fmt.Sprintf("%d", v)
|
|
case uint:
|
|
return fmt.Sprintf("%d", v)
|
|
case uint8:
|
|
return fmt.Sprintf("%d", v)
|
|
case uint16:
|
|
return fmt.Sprintf("%d", v)
|
|
case uint32:
|
|
return fmt.Sprintf("%d", v)
|
|
case uint64:
|
|
return fmt.Sprintf("%d", v)
|
|
case float32:
|
|
return fmt.Sprintf("%g", v)
|
|
case float64:
|
|
return fmt.Sprintf("%g", v)
|
|
case time.Time:
|
|
return quoteString(v.Format(time.RFC3339Nano))
|
|
default:
|
|
return quoteString(fmt.Sprint(v))
|
|
}
|
|
}
|
|
|
|
// quoteString wraps value in single quotes and doubles every embedded single
// quote, yielding a SQL string literal.
func quoteString(value string) string {
	const apostrophe = "'"
	escaped := strings.ReplaceAll(value, apostrophe, apostrophe+apostrophe)
	return apostrophe + escaped + apostrophe
}
|
|
|
|
// isPrintableUTF8 reports whether data can be written verbatim as a SQL text
// literal: it must be valid UTF-8 and must not contain a NUL byte. Empty
// input is accepted.
func isPrintableUTF8(data []byte) bool {
	text := string(data)
	for i, r := range text {
		if r == 0 {
			return false
		}
		// Ranging over a string yields U+FFFD for every byte that does not
		// start a valid encoding. A genuine U+FFFD in the input is the only
		// case where the replacement character's own encoding is present at
		// this offset, so anything else means the data is not valid UTF-8.
		if r == '\uFFFD' && !strings.HasPrefix(text[i:], "\uFFFD") {
			return false
		}
	}
	return true
}
|