diff --git a/base_test.go b/base_test.go index 75575292..139b4529 100644 --- a/base_test.go +++ b/base_test.go @@ -538,6 +538,102 @@ func testCols(engine *Engine, t *testing.T) { fmt.Println(tmpUsers) } +type tempUser2 struct { + tempUser `xorm:"extends"` + Departname string +} + +func testExtends(engine *Engine, t *testing.T) { + err := engine.DropTables(&tempUser2{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(&tempUser2{}) + if err != nil { + t.Error(err) + panic(err) + } + + tu := &tempUser2{tempUser{0, "extends"}, "dev depart"} + _, err = engine.Insert(tu) + if err != nil { + t.Error(err) + panic(err) + } + + tu2 := &tempUser2{} + _, err = engine.Get(tu2) + if err != nil { + t.Error(err) + panic(err) + } + + tu3 := &tempUser2{tempUser{0, "extends update"}, ""} + _, err = engine.Id(tu2.Id).Update(tu3) + if err != nil { + t.Error(err) + panic(err) + } +} + +type allCols struct { + Bit int `xorm:"BIT"` + TinyInt int8 `xorm:"TINYINT"` + SmallInt int16 `xorm:"SMALLINT"` + MediumInt int32 `xorm:"MEDIUMINT"` + Int int `xorm:"INT"` + Integer int `xorm:"INTEGER"` + BigInt int64 `xorm:"BIGINT"` + + Char string `xorm:"CHAR(12)"` + Varchar string `xorm:"VARCHAR(54)"` + TinyText string `xorm:"TINYTEXT"` + Text string `xorm:"TEXT"` + MediumText string `xorm:"MEDIUMTEXT"` + LongText string `xorm:"LONGTEXT"` + Binary string `xorm:"BINARY"` + VarBinary string `xorm:"VARBINARY(12)"` + + Date time.Time `xorm:"DATE"` + DateTime time.Time `xorm:"DATETIME"` + Time time.Time `xorm:"TIME"` + TimeStamp time.Time `xorm:"TIMESTAMP"` + + Decimal float64 `xorm:"DECIMAL"` + Numeric float64 `xorm:"NUMERIC"` + + Real float32 `xorm:"REAL"` + Float float32 `xorm:"FLOAT"` + Double float64 `xorm:"DOUBLE"` + + TinyBlob []byte `xorm:"TINYBLOB"` + Blob []byte `xorm:"BLOB"` + MediumBlob []byte `xorm:"MEDIUMBLOB"` + LongBlob []byte `xorm:"LONGBLOB"` + Bytea []byte `xorm:"BYTEA"` + + Bool bool `xorm:"BOOL"` + + Serial int `xorm:"SERIAL"` + //BigSerial int64 `xorm:"BIGSERIAL"` +} + +func testColTypes(engine *Engine, t *testing.T) { + err := engine.DropTables(&allCols{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(&allCols{}) + if err != nil { + t.Error(err) + panic(err) + } +} + func testTrans(engine *Engine, t *testing.T) { } @@ -571,4 +667,6 @@ func testAll(engine *Engine, t *testing.T) { testCols(engine, t) testCharst(engine, t) testStoreEngine(engine, t) + testExtends(engine, t) + testColTypes(engine, t) } diff --git a/engine.go b/engine.go index bcd30357..970dcd0b 100644 --- a/engine.go +++ b/engine.go @@ -212,110 +212,91 @@ func (engine *Engine) AutoMap(bean interface{}) *Table { func (engine *Engine) MapType(t reflect.Type) *Table { table := &Table{Name: engine.Mapper.Obj2Table(t.Name()), Type: t, Indexes: map[string][]string{}, Uniques: map[string][]string{}} - table.Columns = make(map[string]Column) + table.Columns = make(map[string]*Column) + var idFieldColName string for i := 0; i < t.NumField(); i++ { tag := t.Field(i).Tag ormTagStr := tag.Get(engine.TagIdentifier) - var col Column + var col *Column fieldType := t.Field(i).Type if ormTagStr != "" { - col = Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, + col = &Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, IsAutoIncrement: false, MapType: TWOSIDES} - ormTagStr = strings.ToLower(ormTagStr) tags := strings.Split(ormTagStr, " ") - // TODO: + if len(tags) > 0 { if tags[0] == "-" { continue } - if (tags[0] == "extends") && - (fieldType.Kind() == reflect.Struct) && - t.Field(i).Anonymous { + if (strings.ToUpper(tags[0]) == "EXTENDS") && + (fieldType.Kind() == reflect.Struct) { parentTable := engine.MapType(fieldType) for name, col := range parentTable.Columns { col.FieldName = fmt.Sprintf("%v.%v", fieldType.Name(), col.FieldName) table.Columns[name] = col } + + table.PrimaryKey = parentTable.PrimaryKey + continue } for j, key := range tags { - k := strings.ToLower(key) + k := strings.ToUpper(key) switch { case k == "<-": col.MapType = ONLYFROMDB case k == "->": col.MapType = ONLYTODB - case k == "pk": + case k == "PK": col.IsPrimaryKey = true col.Nullable = false - case k == "null": - col.Nullable = (tags[j-1] != "not") - case k == "autoincr": + table.PrimaryKey = col.Name + case k == "NULL": + col.Nullable = (strings.ToUpper(tags[j-1]) != "NOT") + case k == "AUTOINCR": col.IsAutoIncrement = true - case k == "default": + case k == "DEFAULT": col.Default = tags[j+1] - case k == "text": - col.SQLType = Text - case k == "blob": - col.SQLType = Blob - case strings.HasPrefix(k, "int"): - if k == "int" { - col.SQLType = Int - col.Length = Int.DefaultLength - col.Length2 = Int.DefaultLength2 - } else { - col.SQLType = Int - lens := k[len("int")+1 : len(k)-1] - col.Length, _ = strconv.Atoi(lens) - } - case strings.HasPrefix(k, "varchar"): - if k == "varchar" { - col.SQLType = Varchar - col.Length = Varchar.DefaultLength - col.Length2 = Varchar.DefaultLength2 - } else { - col.SQLType = Varchar - lens := k[len("varchar")+1 : len(k)-1] - col.Length, _ = strconv.Atoi(lens) - } - case strings.HasPrefix(k, "decimal"): - col.SQLType = Decimal - lens := k[len("decimal")+1 : len(k)-1] - twolen := strings.Split(lens, ",") - col.Length, _ = strconv.Atoi(twolen[0]) - col.Length2, _ = strconv.Atoi(twolen[1]) - case strings.HasPrefix(k, "index"): - if k == "index" { + case strings.HasPrefix(k, "INDEX"): + if k == "INDEX" { col.IndexName = "" col.IndexType = SINGLEINDEX } else { - col.IndexName = k[len("index")+1 : len(k)-1] + col.IndexName = k[len("INDEX")+1 : len(k)-1] col.IndexType = UNIONINDEX } - case strings.HasPrefix(k, "unique"): - if k == "unique" { + case strings.HasPrefix(k, "UNIQUE"): + if k == "UNIQUE" { col.UniqueName = "" col.UniqueType = SINGLEUNIQUE } else { - col.UniqueName = k[len("unique")+1 : len(k)-1] + col.UniqueName = k[len("UNIQUE")+1 : len(k)-1] col.UniqueType = UNIONUNIQUE } - case k == "date": - col.SQLType = Date - case k == "float": - col.SQLType = Float - case k == "double": - col.SQLType = Double - case k == "datetime": - col.SQLType = DateTime - case k == "timestamp": - col.SQLType = TimeStamp - case k == "not": + case k == "NOT": default: - if k != col.Default { - col.Name = k + if strings.Contains(k, "(") && strings.HasSuffix(k, ")") { + fs := strings.Split(k, "(") + if _, ok := sqlTypes[fs[0]]; !ok { + continue + } + col.SQLType = SQLType{fs[0], 0, 0} + fs2 := strings.Split(fs[1][0:len(fs[1])-1], ",") + if len(fs2) == 2 { + col.Length, _ = strconv.Atoi(fs2[0]) + col.Length2, _ = strconv.Atoi(fs2[1]) + } else if len(fs2) == 1 { + col.Length, _ = strconv.Atoi(fs2[0]) + } + } else { + if _, ok := sqlTypes[k]; ok { + col.SQLType = SQLType{k, 0, 0} + } else if k != col.Default { + col.Name = key + } } + engine.SqlType(col) } } if col.SQLType.Name == "" { @@ -353,24 +334,31 @@ func (engine *Engine) MapType(t reflect.Type) *Table { table.Uniques[col.UniqueName] = []string{col.Name} } } - - if col.IsPrimaryKey { - table.PrimaryKey = col.Name - } } } else { sqlType := Type2SQLType(fieldType) - col = Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, + col = &Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, sqlType.DefaultLength, sqlType.DefaultLength2, true, "", NONEUNIQUE, "", NONEINDEX, "", false, false, TWOSIDES} - if col.Name == "id" { - col.IsPrimaryKey = true - col.IsAutoIncrement = true - col.Nullable = false - table.PrimaryKey = col.Name - } + } + if col.IsAutoIncrement { + col.Nullable = false + } + if col.IsPrimaryKey { + table.PrimaryKey = col.Name } table.Columns[col.Name] = col + if col.FieldName == "Id" || strings.HasSuffix(col.FieldName, ".Id") { + idFieldColName = col.Name + } + } + + if idFieldColName != "" && table.PrimaryKey == "" { + col := table.Columns[idFieldColName] + col.IsPrimaryKey = true + col.IsAutoIncrement = true + col.Nullable = false + table.PrimaryKey = col.Name } return table diff --git a/filter.go b/filter.go index 8c4d4ea0..6f93aba6 100644 --- a/filter.go +++ b/filter.go @@ -37,7 +37,7 @@ type IdFilter struct { func (i *IdFilter) Do(sql string, session *Session) string { if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" { - return strings.Replace(sql, "(id)", session.Statement.RefTable.PrimaryKey, -1) + return strings.Replace(sql, "(id)", session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1) } return sql } diff --git a/mysql.go b/mysql.go index 6fba8dc0..5acaefb3 100644 --- a/mysql.go +++ b/mysql.go @@ -7,26 +7,34 @@ package xorm -import "strconv" +import ( + "fmt" + "strconv" +) type mysql struct { } func (db *mysql) SqlType(c *Column) string { var res string - switch t := c.SQLType; t { + fmt.Println("-----", c.Name, c.SQLType.Name, "-----") + switch t := c.SQLType.Name; t { case Bool: - res = TinyInt.Name + res = TinyInt case Serial: c.IsAutoIncrement = true - res = Int.Name + c.IsPrimaryKey = true + c.Nullable = false + res = Int case BigSerial: c.IsAutoIncrement = true - res = Integer.Name + c.IsPrimaryKey = true + c.Nullable = false + res = BigInt case Bytea: - res = Blob.Name + res = Blob default: - res = t.Name + res = t } var hasLen1 bool = (c.Length > 0) diff --git a/postgres.go b/postgres.go index 2b7e3787..1a8040e9 100644 --- a/postgres.go +++ b/postgres.go @@ -14,31 +14,32 @@ type postgres struct { func (db *postgres) SqlType(c *Column) string { var res string - switch t := c.SQLType; t { + switch t := c.SQLType.Name; t { case TinyInt: - res = SmallInt.Name + res = SmallInt case MediumInt, Int, Integer: - return Integer.Name + return Integer case Serial, BigSerial: c.IsAutoIncrement = true - res = t.Name + c.Nullable = false + res = t case Binary, VarBinary: - res = Bytea.Name + return Bytea case DateTime: - res = TimeStamp.Name + res = TimeStamp case Float: - res = Real.Name + res = Real case TinyText, MediumText, LongText: - res = Text.Name + res = Text case Blob, TinyBlob, MediumBlob, LongBlob: - res = Bytea.Name + return Bytea case Double: return "DOUBLE PRECISION" default: if c.IsAutoIncrement { - return Serial.Name + return Serial } - res = t.Name + res = t } var hasLen1 bool = (c.Length > 0) diff --git a/session.go b/session.go index 7370a12a..a1c069c5 100644 --- a/session.go +++ b/session.go @@ -733,7 +733,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error colNames := make([]string, 0) colMultiPlaces := make([]string, 0) var args = make([]interface{}, 0) - cols := make([]Column, 0) + cols := make([]*Column, 0) for i := 0; i < size; i++ { elemValue := sliceValue.Index(i).Interface() @@ -864,13 +864,15 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { var args = make([]interface{}, 0) for _, col := range table.Columns { - fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) - if col.IsAutoIncrement && fieldValue.Int() == 0 { - continue - } if col.MapType == ONLYFROMDB { continue } + + fieldValue := col.ValueOf(bean) + if col.IsAutoIncrement && fieldValue.Int() == 0 { + continue + } + if session.Statement.ColumnStr != "" { if _, ok := session.Statement.columnMap[col.Name]; !ok { continue @@ -906,8 +908,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } var id int64 = 0 - pkValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(table.PKColumn().FieldName) - if pkValue.Int() != 0 || !pkValue.CanSet() { + pkValue := table.PKColumn().ValueOf(bean) + if !pkValue.IsValid() || pkValue.Int() != 0 || !pkValue.CanSet() { return 0, nil } diff --git a/sqlite3.go b/sqlite3.go index 48e5a636..bd5712b5 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -11,24 +11,26 @@ type sqlite3 struct { } func (db *sqlite3) SqlType(c *Column) string { - switch t := c.SQLType; t { + switch t := c.SQLType.Name; t { case Date, DateTime, TimeStamp, Time: - return Numeric.Name + return Numeric case Char, Varchar, TinyText, Text, MediumText, LongText: - return Text.Name + return Text case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool: - return Integer.Name + return Integer case Float, Double, Real: - return Real.Name + return Real case Decimal, Numeric: - return Numeric.Name + return Numeric case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary: - return Blob.Name + return Blob case Serial, BigSerial: + c.IsPrimaryKey = true c.IsAutoIncrement = true - return Integer.Name + c.Nullable = false + return Integer default: - return t.Name + return t } } diff --git a/statement.go b/statement.go index cd64960a..b23dcff8 100644 --- a/statement.go +++ b/statement.go @@ -85,7 +85,7 @@ func BuildConditions(engine *Engine, table *Table, bean interface{}) ([]string, colNames := make([]string, 0) var args = make([]interface{}, 0) for _, col := range table.Columns { - fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) + fieldValue := col.ValueOf(bean) fieldType := reflect.TypeOf(fieldValue.Interface()) val := fieldValue.Interface() switch fieldType.Kind() { diff --git a/table.go b/table.go index e313ee54..356aa015 100644 --- a/table.go +++ b/table.go @@ -10,7 +10,7 @@ package xorm import ( "reflect" //"strconv" - //"strings" + "strings" "time" ) @@ -21,46 +21,86 @@ type SQLType struct { } var ( - Bit = SQLType{"BIT", 0, 0} - TinyInt = SQLType{"TINYINT", 0, 0} - SmallInt = SQLType{"SMALLINT", 0, 0} - MediumInt = SQLType{"MEDIUMINT", 0, 0} - Int = SQLType{"INT", 0, 0} - Integer = SQLType{"INTEGER", 0, 0} - BigInt = SQLType{"BIGINT", 0, 0} + Bit = "BIT" + TinyInt = "TINYINT" + SmallInt = "SMALLINT" + MediumInt = "MEDIUMINT" + Int = "INT" + Integer = "INTEGER" + BigInt = "BIGINT" - Char = SQLType{"CHAR", 0, 0} - Varchar = SQLType{"VARCHAR", 64, 0} - TinyText = SQLType{"TINYTEXT", 0, 0} - Text = SQLType{"TEXT", 0, 0} - MediumText = SQLType{"MEDIUMTEXT", 0, 0} - LongText = SQLType{"LONGTEXT", 0, 0} - Binary = SQLType{"BINARY", 0, 0} - VarBinary = SQLType{"VARBINARY", 0, 0} + Char = "CHAR" + Varchar = "VARCHAR" + TinyText = "TINYTEXT" + Text = "TEXT" + MediumText = "MEDIUMTEXT" + LongText = "LONGTEXT" + Binary = "BINARY" + VarBinary = "VARBINARY" - Date = SQLType{"DATE", 0, 0} - DateTime = SQLType{"DATETIME", 0, 0} - Time = SQLType{"TIME", 0, 0} - TimeStamp = SQLType{"TIMESTAMP", 0, 0} + Date = "DATE" + DateTime = "DATETIME" + Time = "TIME" + TimeStamp = "TIMESTAMP" - Decimal = SQLType{"DECIMAL", 26, 2} - Numeric = SQLType{"NUMERIC", 0, 0} + Decimal = "DECIMAL" + Numeric = "NUMERIC" - Real = SQLType{"REAL", 0, 0} - Float = SQLType{"FLOAT", 0, 0} - Double = SQLType{"DOUBLE", 0, 0} - //Money = SQLType{"MONEY", 0, 0} + Real = "REAL" + Float = "FLOAT" + Double = "DOUBLE" - TinyBlob = SQLType{"TINYBLOB", 0, 0} - Blob = SQLType{"BLOB", 0, 0} - MediumBlob = SQLType{"MEDIUMBLOB", 0, 0} - LongBlob = SQLType{"LONGBLOB", 0, 0} - Bytea = SQLType{"BYTEA", 0, 0} + TinyBlob = "TINYBLOB" + Blob = "BLOB" + MediumBlob = "MEDIUMBLOB" + LongBlob = "LONGBLOB" + Bytea = "BYTEA" - Bool = SQLType{"BOOL", 0, 0} + Bool = "BOOL" - Serial = SQLType{"SERIAL", 0, 0} - BigSerial = SQLType{"BIGSERIAL", 0, 0} + Serial = "SERIAL" + BigSerial = "BIGSERIAL" + + sqlTypes = map[string]bool{ + Bit: true, + TinyInt: true, + SmallInt: true, + MediumInt: true, + Int: true, + Integer: true, + BigInt: true, + + Char: true, + Varchar: true, + TinyText: true, + Text: true, + MediumText: true, + LongText: true, + Binary: true, + VarBinary: true, + + Date: true, + DateTime: true, + Time: true, + TimeStamp: true, + + Decimal: true, + Numeric: true, + + Real: true, + Float: true, + Double: true, + TinyBlob: true, + Blob: true, + MediumBlob: true, + LongBlob: true, + Bytea: true, + + Bool: true, + + Serial: true, + BigSerial: true, + } ) var b byte @@ -69,29 +109,29 @@ var tm time.Time func Type2SQLType(t reflect.Type) (st SQLType) { switch k := t.Kind(); k { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: - st = Int + st = SQLType{Int, 0, 0} case reflect.Int64, reflect.Uint64: - st = BigInt + st = SQLType{BigInt, 0, 0} case reflect.Float32: - st = Float + st = SQLType{Float, 0, 0} case reflect.Float64: - st = Double + st = SQLType{Double, 0, 0} case reflect.Complex64, reflect.Complex128: - st = Varchar + st = SQLType{Varchar, 64, 0} case reflect.Array, reflect.Slice: if t.Elem() == reflect.TypeOf(b) { - st = Blob + st = SQLType{Blob, 0, 0} } case reflect.Bool: - st = TinyInt + st = SQLType{Bool, 0, 0} case reflect.String: - st = Varchar + st = SQLType{Varchar, 64, 0} case reflect.Struct: if t == reflect.TypeOf(tm) { - st = DateTime + st = SQLType{DateTime, 0, 0} } default: - st = Varchar + st = SQLType{Varchar, 64, 0} } return } @@ -156,16 +196,32 @@ func (col *Column) String(engine *Engine) string { return sql } +func (col *Column) ValueOf(bean interface{}) reflect.Value { + var fieldValue reflect.Value + if strings.Contains(col.FieldName, ".") { + fields := strings.Split(col.FieldName, ".") + if len(fields) > 2 { + return reflect.ValueOf(nil) + } + + fieldValue = reflect.Indirect(reflect.ValueOf(bean)).FieldByName(fields[0]) + fieldValue = fieldValue.FieldByName(fields[1]) + } else { + fieldValue = reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) + } + return fieldValue +} + type Table struct { Name string Type reflect.Type - Columns map[string]Column + Columns map[string]*Column Indexes map[string][]string Uniques map[string][]string PrimaryKey string } -func (table *Table) PKColumn() Column { +func (table *Table) PKColumn() *Column { return table.Columns[table.PrimaryKey] }