mysql_extend.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. package util
  2. import (
  3. "database/sql"
  4. "errors"
  5. "fmt"
  6. "reflect"
  7. "strconv"
  8. "strings"
  9. )
  10. func ScanStruct(rows *sql.Rows, dest interface{}) error {
  11. rv := reflect.ValueOf(dest)
  12. if rv.Kind() != reflect.Ptr {
  13. return errors.New("must pass a pointer, not a value, to StructScan destination")
  14. }
  15. if rv.IsNil() {
  16. return errors.New("nil pointer passed to StructScan destination")
  17. }
  18. direct := rv.Elem()
  19. dtype := direct.Type()
  20. if dtype.Kind() != reflect.Slice {
  21. return fmt.Errorf("expected %s but got %s", reflect.Slice, dtype.Kind())
  22. }
  23. elemType := dtype.Elem()
  24. var (
  25. field reflect.StructField
  26. fmap = map[string]TagIndex{}
  27. dbtagVal string
  28. tagSplits []string
  29. fimap = map[TagIndex]int{}
  30. )
  31. cols, _ := rows.Columns()
  32. for i := 0; i < elemType.NumField(); i++ {
  33. field = elemType.Field(i)
  34. dbtagVal = field.Tag.Get("db")
  35. if dbtagVal == "null" {
  36. continue
  37. }
  38. tagIdx := TagIndex{}
  39. tagSplits = strings.Split(dbtagVal, ",")
  40. if len(tagSplits) > 1 {
  41. dbtagVal = tagSplits[0]
  42. if tagSplits[1] == "omitempty" {
  43. tagIdx.IsOmitEmpty = true
  44. }
  45. }
  46. if dbtagVal == "" {
  47. dbtagVal = field.Name
  48. }
  49. tagIdx.Idx = i
  50. fmap[dbtagVal] = tagIdx
  51. }
  52. var (
  53. tempStr sql.NullString
  54. valArr []interface{}
  55. )
  56. for rows.Next() {
  57. nelemVal := reflect.New(elemType)
  58. valArr = make([]interface{}, len(cols))
  59. for i, col := range cols {
  60. if fieldIdx, ok := fmap[col]; !ok {
  61. valArr[i] = new(interface{})
  62. } else {
  63. field = elemType.Field(fieldIdx.Idx)
  64. if fieldIdx.IsOmitEmpty {
  65. valArr[i] = &sql.NullString{}
  66. fimap[fieldIdx] = i
  67. } else {
  68. valArr[i] = nelemVal.Elem().Field(fieldIdx.Idx).Addr().Interface()
  69. }
  70. }
  71. }
  72. rows.Scan(valArr...)
  73. for k, v := range fimap {
  74. fval := reflect.ValueOf(valArr[v])
  75. if fval.IsValid() {
  76. field := elemType.Field(k.Idx)
  77. dstFval := reflect.New(field.Type)
  78. tempStr = fval.Elem().Interface().(sql.NullString)
  79. if !tempStr.Valid {
  80. continue
  81. }
  82. switch field.Type.Kind() {
  83. case reflect.Float64:
  84. float, _ := strconv.ParseFloat(tempStr.String, 64)
  85. float = ToFix(float, 2)
  86. dstFval.Elem().Set(reflect.ValueOf(float))
  87. case reflect.String:
  88. dstFval.Elem().Set(reflect.ValueOf(tempStr.String))
  89. case reflect.Int:
  90. intVal, _ := strconv.ParseInt(tempStr.String, 10, 0)
  91. dstFval.Elem().Set(reflect.ValueOf(int(intVal)))
  92. }
  93. //fmt.Println(dstFval.Elem().Interface())
  94. nelemVal.Elem().Field(k.Idx).Set(dstFval.Elem())
  95. }
  96. }
  97. direct.Set(reflect.Append(direct, nelemVal.Elem()))
  98. }
  99. return nil
  100. }