package main import ( "database/sql" "encoding/hex" "flag" "fmt" "hr_receiver/config" "io" "os" "path/filepath" "strings" "time" ) var defaultTables = []string{"users", "user_region_bindings", "kindergartens"} func main() { var ( tablesArg string outputArg string schemaArg string ) flag.StringVar(&tablesArg, "tables", "", "Comma-separated table names to export. Default: users,user_region_bindings,kindergartens") flag.StringVar(&outputArg, "output", "", "Output .sql file path. Default: export_.sql") flag.StringVar(&schemaArg, "schema", "public", "Database schema name") flag.Parse() tables := resolveTables(tablesArg, flag.Args()) if len(tables) == 0 { fmt.Fprintln(os.Stderr, "no tables specified") os.Exit(1) } config.InitConfig() config.ConnectDB() outputPath := outputArg if outputPath == "" { outputPath = fmt.Sprintf("export_%s.sql", time.Now().Format("20060102_150405")) } if err := os.MkdirAll(filepath.Dir(outputPath), 0o755); err != nil && filepath.Dir(outputPath) != "." { fmt.Fprintf(os.Stderr, "create output directory failed: %v\n", err) os.Exit(1) } file, err := os.Create(outputPath) if err != nil { fmt.Fprintf(os.Stderr, "create output file failed: %v\n", err) os.Exit(1) } defer file.Close() if err := exportTables(file, schemaArg, tables); err != nil { fmt.Fprintf(os.Stderr, "export failed: %v\n", err) os.Exit(1) } fmt.Printf("exported %d table(s) to %s\n", len(tables), outputPath) } func resolveTables(tablesArg string, positional []string) []string { if len(positional) > 0 { return dedupeNonEmpty(positional) } if strings.TrimSpace(tablesArg) == "" { return append([]string(nil), defaultTables...) } return dedupeNonEmpty(strings.Split(tablesArg, ",")) } func dedupeNonEmpty(items []string) []string { result := make([]string, 0, len(items)) seen := make(map[string]struct{}, len(items)) for _, item := range items { value := strings.TrimSpace(item) if value == "" { continue } if _, ok := seen[value]; ok { continue } seen[value] = struct{}{} result = append(result, value) } return result } func exportTables(w io.Writer, schema string, tables []string) error { db, err := config.DB.DB() if err != nil { return err } header := []string{ "-- Generated by cmd/export_sql", fmt.Sprintf("-- Time: %s", time.Now().Format(time.RFC3339)), fmt.Sprintf("-- Schema: %s", schema), fmt.Sprintf("-- Tables: %s", strings.Join(tables, ", ")), "", } if _, err := io.WriteString(w, strings.Join(header, "\n")); err != nil { return err } for _, table := range tables { columns, err := getColumns(db, schema, table) if err != nil { return fmt.Errorf("read columns for %s: %w", table, err) } if len(columns) == 0 { return fmt.Errorf("table %s not found or has no columns", table) } orderBy, err := getPrimaryKeyColumns(db, schema, table) if err != nil { return fmt.Errorf("read primary key for %s: %w", table, err) } if _, err := fmt.Fprintf(w, "-- Table: %s\n", table); err != nil { return err } if err := exportTableData(w, db, schema, table, columns, orderBy); err != nil { return err } if _, err := io.WriteString(w, "\n"); err != nil { return err } } return nil } func getColumns(db *sql.DB, schema, table string) ([]string, error) { const query = ` SELECT column_name FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 ORDER BY ordinal_position ` rows, err := db.Query(query, schema, table) if err != nil { return nil, err } defer rows.Close() var columns []string for rows.Next() { var column string if err := rows.Scan(&column); err != nil { return nil, err } columns = append(columns, column) } return columns, rows.Err() } func getPrimaryKeyColumns(db *sql.DB, schema, table string) ([]string, error) { const query = ` SELECT a.attname FROM pg_index i JOIN pg_class c ON c.oid = i.indrelid JOIN pg_namespace n ON n.oid = c.relnamespace JOIN pg_attribute a ON a.attrelid = c.oid AND a.attnum = ANY(i.indkey) WHERE i.indisprimary AND n.nspname = $1 AND c.relname = $2 ORDER BY array_position(i.indkey, a.attnum) ` rows, err := db.Query(query, schema, table) if err != nil { return nil, err } defer rows.Close() var columns []string for rows.Next() { var column string if err := rows.Scan(&column); err != nil { return nil, err } columns = append(columns, column) } return columns, rows.Err() } func exportTableData(w io.Writer, db *sql.DB, schema, table string, columns, orderBy []string) error { query := fmt.Sprintf( `SELECT %s FROM %s.%s`, joinIdentifiers(columns), quoteIdentifier(schema), quoteIdentifier(table), ) if len(orderBy) > 0 { query += " ORDER BY " + joinIdentifiers(orderBy) } rows, err := db.Query(query) if err != nil { return fmt.Errorf("query rows for %s: %w", table, err) } defer rows.Close() values := make([]any, len(columns)) scanTargets := make([]any, len(columns)) for i := range values { scanTargets[i] = &values[i] } rowCount := 0 for rows.Next() { clear(values) if err := rows.Scan(scanTargets...); err != nil { return fmt.Errorf("scan row for %s: %w", table, err) } literals := make([]string, len(columns)) for i, value := range values { literals[i] = toSQLLiteral(value) } stmt := fmt.Sprintf( "INSERT INTO %s.%s (%s) VALUES (%s);\n", quoteIdentifier(schema), quoteIdentifier(table), joinIdentifiers(columns), strings.Join(literals, ", "), ) if _, err := io.WriteString(w, stmt); err != nil { return err } rowCount++ } if err := rows.Err(); err != nil { return err } if rowCount == 0 { _, err := io.WriteString(w, "-- no rows\n") return err } return nil } func joinIdentifiers(columns []string) string { quoted := make([]string, len(columns)) for i, column := range columns { quoted[i] = quoteIdentifier(column) } return strings.Join(quoted, ", ") } func quoteIdentifier(value string) string { return `"` + strings.ReplaceAll(value, `"`, `""`) + `"` } func toSQLLiteral(value any) string { switch v := value.(type) { case nil: return "NULL" case string: return quoteString(v) case []byte: if len(v) == 0 { return quoteString("") } if isPrintableUTF8(v) { return quoteString(string(v)) } return quoteString(`\x` + hex.EncodeToString(v)) case bool: if v { return "TRUE" } return "FALSE" case int: return fmt.Sprintf("%d", v) case int8: return fmt.Sprintf("%d", v) case int16: return fmt.Sprintf("%d", v) case int32: return fmt.Sprintf("%d", v) case int64: return fmt.Sprintf("%d", v) case uint: return fmt.Sprintf("%d", v) case uint8: return fmt.Sprintf("%d", v) case uint16: return fmt.Sprintf("%d", v) case uint32: return fmt.Sprintf("%d", v) case uint64: return fmt.Sprintf("%d", v) case float32: return fmt.Sprintf("%g", v) case float64: return fmt.Sprintf("%g", v) case time.Time: return quoteString(v.Format(time.RFC3339Nano)) default: return quoteString(fmt.Sprint(v)) } } func quoteString(value string) string { return "'" + strings.ReplaceAll(value, "'", "''") + "'" } func isPrintableUTF8(data []byte) bool { for _, b := range data { if b == 0 { return false } } return true }