先前的脚本虽然能生成实体类，但没有生成序列化 ID（serialVersionUID），实体较多时逐个手动添加很麻烦，所以这次加上了序列化 ID 的自动生成。
无serialVersionUID版本
改进
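参考 JDK 默认 serialVersionUID 的计算方式（`java.io.ObjectStreamClass` 会把类名、修饰符、接口、字段、构造器和方法的签名写入数据流，做 SHA 摘要后取前 8 个字节作为 ID），添加了签名计算。下面直接贴出新增的代码：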
/**
 * Emits a blank line followed by the serialVersionUID constant of the generated class.
 *
 * @param out       writer receiving the generated source
 * @param fields    field descriptors of the generated class
 * @param className simple name of the generated class
 */
def outSerialVersionUID(out, fields, className) {
    out.println ""
    long uid = calculateSerialId(packageName, className, fields)
    out.println "\tprivate static final long serialVersionUID = ${uid}L;"
}
/**
 * Computes the serialVersionUID the JDK would derive by default for the generated class.
 *
 * Follows the default algorithm of java.io.ObjectStreamClass, restricted to what this
 * generator emits:
 *  - a public class implementing java.io.Serializable
 *  - private instance fields
 *  - a public no-arg constructor
 *  - public getters/setters and toString()
 * The serialVersionUID field itself is private static, so it is not hashed.
 *
 * @param packageName package of the generated class with the trailing ';' produced by
 *                    getPackageName (";" or "" means the default package)
 * @param className   simple name of the generated class
 * @param fields      field descriptors (name, type)
 * @return the 64-bit hash, or 0 when the class has no fields
 */
long calculateSerialId(packageName, className, fields) {
    // Field-less classes get 0L rather than a computed hash.
    if (fields.size() == 0) return 0L
    String pkg = packageName.endsWith(";") ? packageName.substring(0, packageName.length() - 1) : packageName
    // Binary class name; a class in the default package has no "." prefix.
    String binaryName = pkg ? pkg + "." + className : className

    ByteArrayOutputStream bout = new ByteArrayOutputStream()
    DataOutputStream dout = new DataOutputStream(bout)
    // Class name and class modifiers.
    dout.writeUTF(binaryName)
    dout.writeInt(Modifier.PUBLIC)
    // Implemented interfaces, sorted by name.
    dout.writeUTF("java.io.Serializable")
    // Fields sorted by name; field descriptors keep their '/' separators.
    getFieldsSignatures(fields).each {
        dout.writeUTF(it.name)
        dout.writeInt(Modifier.PRIVATE)
        dout.writeUTF(it.signature)
    }
    // The public no-arg constructor.
    dout.writeUTF("<init>")
    dout.writeInt(Modifier.PUBLIC)
    dout.writeUTF("()V")
    // Methods sorted by name then signature; method descriptors use '.' separators.
    getMethodSignatures(fields, className).each {
        dout.writeUTF(it.name)
        dout.writeInt(Modifier.PUBLIC)
        dout.writeUTF(it.signature.replace('/', '.'))
    }
    dout.flush()

    byte[] hashBytes = MessageDigest.getInstance("SHA").digest(bout.toByteArray())
    // Fold the first eight digest bytes into a long, byte 0 ending up least significant.
    long hash = 0L
    for (int i = Math.min(hashBytes.length, 8) - 1; i >= 0; i--) {
        hash = (hash << 8) | (hashBytes[i] & 0xFF)
    }
    return hash
}
/**
 * Describes the generated class's instance fields for calculateSerialId.
 *
 * @param fields field descriptors (name, type)
 * @return maps with keys name and signature (JVM descriptor), ordered by field name
 */
def getFieldsSignatures(fields) {
    def result = fields.collect { field ->
        [name: field.name, signature: getClassSignature(field.type)]
    }
    result.sort { left, right -> left.name <=> right.name }
    return result
}
/**
 * Lists the public methods of the generated class as name/signature maps, sorted by name
 * and then by signature, as input for calculateSerialId.
 *
 * Covers toString() plus one getter and one setter per field. Setters return the
 * generated class itself, so their descriptor ends in the class's qualified name.
 *
 * @param fields    field descriptors (name, type)
 * @param className simple name of the generated class
 * @return sorted list of maps with keys name and signature
 */
def getMethodSignatures(fields, className) {
    // Drop the trailing ';' kept on packageName; an empty package means the default package.
    String pkg = packageName.endsWith(";") ? packageName.substring(0, packageName.length() - 1) : packageName
    String qualifiedName = pkg ? pkg + "." + className : className
    def sigs = []
    sigs.add([name: "toString", signature: "()" + getClassSignature("String")])
    fields.each {
        sigs.add([name: "get" + it.name.capitalize(), signature: "()" + getClassSignature(it.type)])
    }
    fields.each {
        sigs.add([name     : "set" + it.name.capitalize(),
                  signature: "(" + getClassSignature(it.type) + ")L" + qualifiedName + ";"])
    }
    sigs.sort { a, b -> (a.name <=> b.name) ?: (a.signature <=> b.signature) }
    return sigs
}
/**
 * Returns the JVM descriptor (e.g. "Ljava/lang/String;") for a Java type name used in the
 * generated source.
 *
 * Known names are resolved through fillTypeMapping. Any other name is treated as fully
 * qualified when it contains a '.'; otherwise it is treated as a class in the generated package.
 *
 * @param type Java type name as written in the generated source
 * @return descriptor of the form "L<internal/name>;"
 */
String getClassSignature(String type) {
    String internalName = fillTypeMapping.get(type)
    if (internalName == null) {
        if (type.contains(".")) {
            internalName = type
        } else {
            String pkg = packageName.endsWith(";") ? packageName.substring(0, packageName.length() - 1) : packageName
            internalName = pkg ? pkg + "." + type : type
        }
    }
    return "L" + internalName.replace(".", "/") + ";"
}
对 Groovy 这门语言不太了解，所以写法上不是最优的，可以说是强行拼凑出来的。
完整版
import com.intellij.database.model.DasTable
import com.intellij.database.util.Case
import com.intellij.database.util.DasUtil
import java.lang.reflect.Modifier
import java.lang.reflect.Proxy
import java.security.MessageDigest
import java.text.SimpleDateFormat
// Script-level state. These are assigned without `def`, so they live in the script binding.
// The methods below share and reassign them:
//   - generate(table, dir) sets packageName
//   - calcFields(table) sets tableComment and hasPrimaryKey
packageName = ""
tableComment = ""
hasPrimaryKey = false
// Lower-cased column type specification -> Java type name used in the generated source.
// calcFields picks the first entry whose regex find()s a match, so order matters:
// "bigint" precedes "int", and "datetime|timestamp" precedes "date" and "time".
// The final empty pattern matches anything, which makes String the fallback.
// NOTE(review): matching is by substring, so e.g. "point"/"interval" also map to Integer.
// Also, decimal maps to Double (possible precision loss). Confirm both are intended.
typeMapping = [
(~/(?i)bigint/) : "Long",
(~/(?i)int/) : "Integer",
(~/(?i)float|double|decimal|real/): "Double",
(~/(?i)datetime|timestamp/) : "Timestamp",
(~/(?i)date/) : "Date",
(~/(?i)time/) : "java.sql.Time",
(~/(?i)bit/) : "Boolean",
(~/(?i)/) : "String"
]
// Java type name (as produced by typeMapping) -> JVM internal class name with '/'
// separators. getClassSignature uses it when hashing the class for serialVersionUID.
fillTypeMapping = [
"Long" : "java/lang/Long",
"Integer" : "java/lang/Integer",
"Double" : "java/lang/Double",
"Boolean" : "java/lang/Boolean",
"String" : "java/lang/String",
"Date" : "java/util/Date",
"java.sql.Time" : "java/sql/Time",
"Timestamp" : "java/sql/Timestamp"
]
// Entry point: show a directory chooser, then generate one class per selected table.
// The two Unicode-escaped arguments decode to Chinese dialog text:
//   - "choose the directory for entity classes"
//   - "used to store the auto-generated files"
// FILES and SELECTION are not defined in this script. They are presumably bindings
// supplied by the IntelliJ database scripting environment.
FILES.chooseDirectoryAndSave(unicodeToString("\\u9009\\u62e9\\u5b9e\\u4f53\\u7c7b\\u5b58\\u653e\\u8def\\u5f84")
, unicodeToString("\\u7528\\u4e8e\\u5b58\\u653e\\u81ea\\u52a8\\u751f\\u6210\\u7684\\u6587\\u4ef6")) { dir ->
SELECTION.filter { it instanceof DasTable }.each { generate(it, dir) }
}
/**
 * Generates "<ClassName>.java" for one table inside the chosen directory.
 *
 * Updates the script-level packageName from the directory before writing. The file is
 * written as UTF-8 and closed when the closure finishes.
 *
 * @param table the selected database table
 * @param dir   target directory chosen by the user
 */
def generate(table, dir) {
    def fields = calcFields(table)
    def className = javaName(table.getName(), true)
    packageName = getPackageName(dir)
    def target = new File(dir, className + ".java")
    target.withPrintWriter("UTF-8") { out ->
        generate(out, className, fields)
    }
}
/**
 * Derives the Java package clause text (including the trailing ';') from a target directory.
 *
 * Path separators become dots. Everything up to the last "src" is dropped, together with a
 * following "main.java." or "test.java." segment. Stray leading or trailing dots are then
 * trimmed, so non-Maven layouts such as src/com/foo also give "com.foo;".
 *
 * @param dir target directory (File or path string)
 * @return e.g. "com.foo;", or ";" when the directory has no package part
 */
def getPackageName(dir) {
    // The trailing '.' lets the optional "main.java." segment match when dir ends there.
    String path = dir.toString().replace('\\', '.').replace('/', '.') + "."
    String pkg = path.replaceAll('^.*src(\\.(main|test)\\.java\\.)?', '')
    pkg = pkg.replaceAll('^\\.+|\\.+$', '')
    return pkg + ";"
}
/**
 * Writes the Java source of one entity class to {@code out}.
 *
 * The output contains, in order:
 *  - the package clause and imports
 *  - a class Javadoc (generator note, generation time, table comment, author)
 *  - the serialVersionUID constant
 *  - one private field per column and a public no-arg constructor
 *  - a getter and a chainable setter per field
 *  - toString()
 *
 * @param out       writer receiving the source text
 * @param className simple name of the generated class
 * @param fields    column descriptors built by calcFields (name, type, comment, ...)
 */
def generate(out, className, fields) {
    // getPackageName yields just ";" when the directory has no package part. A bare
    // "package ;" does not compile, so the clause is omitted in that case.
    if (packageName && packageName != ";") {
        out.println "package $packageName"
    }
    out.println "import java.io.Serializable;"
    Set types = new HashSet()
    fields.each() {
        types.add(it.type)
    }
    if (types.contains("Timestamp")) {
        out.println "import java.sql.Timestamp;"
    }
    if (types.contains("Date")) {
        out.println "import java.util.Date;"
    }
    out.println "/**"
    // Short Chinese date/time format, built without Locale.setDefault. That call would change
    // the default locale of the entire JVM the script runs in, not just this formatter.
    def dateFormat = java.text.DateFormat.getDateTimeInstance(
            java.text.DateFormat.SHORT, java.text.DateFormat.SHORT, Locale.CHINA)
    out.println " * ${unicodeToString('\\u7531\\u0047\\u0072\\u006f\\u006f\\u0076\\u0079\\u81ea\\u52a8\\u751f\\u6210')} "
    out.println " * <p>Date: " + dateFormat.format(new java.util.Date()) + ".</p>"
    out.println " * Description: $tableComment"
    out.println " * @author Mr.Wang"
    out.println " */"
    out.println "public class $className implements Serializable {"
    outSerialVersionUID(out, fields, className)
    fields.each() {
        out.println ""
        out.println "\t/**"
        out.println "\t * ${it.comment}"
        out.println "\t */"
        out.println "\tprivate ${it.type} ${it.name};"
    }
    out.println ""
    out.println "\tpublic $className(){}"
    out.println ""
    fields.each() {
        out.println "\t/**"
        out.println "\t * ${unicodeToString('\\u83b7\\u53d6')} ${it.comment}"
        out.println "\t * @return null or ${it.name}"
        out.println "\t */"
        out.println "\tpublic ${it.type} get${it.name.capitalize()}() {"
        out.println "\t return ${it.name};"
        out.println "\t}"
        out.println ""
        out.println "\t/**"
        out.println "\t * ${unicodeToString('\\u8bbe\\u7f6e')} ${it.comment.toString()}"
        out.println "\t * @param ${it.name} ${it.comment}"
        out.println "\t * @return " + unicodeToString("\\u5f53\\u524d\\u5bf9\\u8c61")
        out.println "\t */"
        // Setters return this so calls can be chained.
        out.println "\tpublic $className set${it.name.capitalize()}(${it.type} ${it.name}) {"
        out.println "\t this.${it.name} = ${it.name};"
        out.println "\t return this;"
        out.println "\t}"
        out.println ""
    }
    outToString(out, className, fields)
    out.println "}"
}
/**
 * Writes a toString() override returning "ClassName [a = <a>, b = <b>]".
 *
 * @param out       writer receiving the generated source
 * @param className simple name of the generated class
 * @param fields    field descriptors; only their names are used
 */
def outToString(out, className, fields) {
    out.println "\t@Override"
    out.println "\tpublic String toString() {"
    if (fields.isEmpty()) {
        // The multi-line form below leaves the string literal open for the first field.
        // Without any field that literal would never be closed properly.
        out.println "\t\treturn \"$className []\";"
    } else {
        out.print "\t\treturn \"$className ["
        fields.eachWithIndex { field, index ->
            if (index == 0) {
                out.println "${field.name} = \" + " + field.name
            } else {
                out.println "\t\t + \", ${field.name} = \" + " + field.name
            }
        }
        out.println "\t\t + \"]\";"
    }
    out.println "\t}"
}
/**
 * Builds the field descriptors for a table and records table-level metadata.
 *
 * Side effects: sets the script-level hasPrimaryKey and tableComment.
 *
 * @param table the selected table
 * @return list of maps with keys colName, name, type, comment, isPk, isFk
 */
def calcFields(table) {
    hasPrimaryKey = DasUtil.getPrimaryKey(table) != null
    // Normalise a missing or blank comment to "" so the class Javadoc never shows "null".
    tableComment = getComment(table.getComment())
    DasUtil.getColumns(table).reduce([]) { fields, col ->
        def spec = Case.LOWER.apply(col.getDataType().getSpecification())
        // The first typeMapping entry whose pattern matches wins.
        def typeStr = typeMapping.find { p, t -> p.matcher(spec).find() }.value
        fields.add([
                colName: col.getName(),
                name   : javaName(col.getName(), false),
                type   : typeStr,
                comment: getComment(col.getComment()),
                isPk   : DasUtil.isPrimary(col),
                // NOTE(review): despite the key name this calls isIndexColumn. Confirm
                // whether a foreign-key check was intended (the value is not read in this script).
                isFk   : DasUtil.isIndexColumn(col)
        ])
        return fields
    }
}
/**
 * Converts a database identifier into a Java identifier.
 *
 * The conversion runs in three steps:
 *  - IntelliJ's NameUtil splits the name into words
 *  - each word is lower-cased, capitalized, and the words are joined
 *  - any character that is not a Java identifier part or '_' becomes '_'
 *
 * @param str        database table or column name
 * @param capitalize true to keep the leading capital (class names), false to lower-case
 *                   the first character (field names)
 * @return the converted identifier
 */
def javaName(str, capitalize) {
    def s = com.intellij.psi.codeStyle.NameUtil.splitNameIntoWords(str)
            .collect { Case.LOWER.apply(it).capitalize() }
            .join("")
            .replaceAll(/[^\p{javaJavaIdentifierPart}[_]]/, "_")
    if (capitalize || s.isEmpty()) {
        return s
    }
    // Lower-case only the first character. substring(1) is safe for one-character names,
    // so single-letter columns also become lower-case fields ("x" stays "x", not "X").
    return Case.LOWER.apply(s.substring(0, 1)) + s.substring(1)
}
/**
 * Tells whether a value contains non-whitespace text.
 *
 * @param content any value, possibly null; its toString() is inspected
 * @return false for null or text that is empty after trimming, true otherwise
 */
def isNotEmpty(content) {
    if (content == null) {
        return false
    }
    return !content.toString().trim().isEmpty()
}
/**
 * Normalises a database comment for the generated Javadoc.
 *
 * @param comment raw comment, possibly null
 * @return the comment's text, or "" when it is null or blank
 */
def getComment(comment) {
    String text = comment?.toString()
    return (text != null && text.trim().length() > 0) ? text : ""
}
/**
 * Decodes backslash-u escapes (a backslash, 'u', then hex digits) into characters.
 *
 * The script keeps its Chinese text in this escaped form. Each escape consumes up to four
 * hex digits. Text after those digits, and text before the first escape, is copied through
 * unchanged. An escape with no hex digits is kept literally.
 *
 * @param unicode text containing backslash-u escapes
 * @return the decoded string
 */
def unicodeToString(String unicode) {
    String[] parts = unicode.split("\\\\u")
    // parts[0] is whatever precedes the first escape ("" when the text starts with one).
    StringBuilder sb = new StringBuilder(parts.length > 0 ? parts[0] : "")
    for (int i = 1; i < parts.length; i++) {
        String part = parts[i]
        int hexLength = 0
        while (hexLength < 4 && hexLength < part.length()
                && Character.digit(part.charAt(hexLength), 16) >= 0) {
            hexLength++
        }
        if (hexLength == 0) {
            sb.append("\\u").append(part)
            continue
        }
        sb.append((char) Integer.parseInt(part.substring(0, hexLength), 16))
        sb.append(part.substring(hexLength))
    }
    return sb.toString()
}
/**
 * Emits a blank line followed by the serialVersionUID constant of the generated class.
 *
 * @param out       writer receiving the generated source
 * @param fields    field descriptors of the generated class
 * @param className simple name of the generated class
 */
def outSerialVersionUID(out, fields, className) {
    out.println ""
    long uid = calculateSerialId(packageName, className, fields)
    out.println "\tprivate static final long serialVersionUID = ${uid}L;"
}
/**
 * Computes the serialVersionUID the JDK would derive by default for the generated class.
 *
 * Follows the default algorithm of java.io.ObjectStreamClass, restricted to what this
 * generator emits:
 *  - a public class implementing java.io.Serializable
 *  - private instance fields
 *  - a public no-arg constructor
 *  - public getters/setters and toString()
 * The serialVersionUID field itself is private static, so it is not hashed.
 *
 * @param packageName package of the generated class with the trailing ';' produced by
 *                    getPackageName (";" or "" means the default package)
 * @param className   simple name of the generated class
 * @param fields      field descriptors (name, type)
 * @return the 64-bit hash, or 0 when the class has no fields
 */
long calculateSerialId(packageName, className, fields) {
    // Field-less classes get 0L rather than a computed hash.
    if (fields.size() == 0) return 0L
    String pkg = packageName.endsWith(";") ? packageName.substring(0, packageName.length() - 1) : packageName
    // Binary class name; a class in the default package has no "." prefix.
    String binaryName = pkg ? pkg + "." + className : className

    ByteArrayOutputStream bout = new ByteArrayOutputStream()
    DataOutputStream dout = new DataOutputStream(bout)
    // Class name and class modifiers.
    dout.writeUTF(binaryName)
    dout.writeInt(Modifier.PUBLIC)
    // Implemented interfaces, sorted by name.
    dout.writeUTF("java.io.Serializable")
    // Fields sorted by name; field descriptors keep their '/' separators.
    getFieldsSignatures(fields).each {
        dout.writeUTF(it.name)
        dout.writeInt(Modifier.PRIVATE)
        dout.writeUTF(it.signature)
    }
    // The public no-arg constructor.
    dout.writeUTF("<init>")
    dout.writeInt(Modifier.PUBLIC)
    dout.writeUTF("()V")
    // Methods sorted by name then signature; method descriptors use '.' separators.
    getMethodSignatures(fields, className).each {
        dout.writeUTF(it.name)
        dout.writeInt(Modifier.PUBLIC)
        dout.writeUTF(it.signature.replace('/', '.'))
    }
    dout.flush()

    byte[] hashBytes = MessageDigest.getInstance("SHA").digest(bout.toByteArray())
    // Fold the first eight digest bytes into a long, byte 0 ending up least significant.
    long hash = 0L
    for (int i = Math.min(hashBytes.length, 8) - 1; i >= 0; i--) {
        hash = (hash << 8) | (hashBytes[i] & 0xFF)
    }
    return hash
}
/**
 * Describes the generated class's instance fields for calculateSerialId.
 *
 * @param fields field descriptors (name, type)
 * @return maps with keys name and signature (JVM descriptor), ordered by field name
 */
def getFieldsSignatures(fields) {
    def result = fields.collect { field ->
        [name: field.name, signature: getClassSignature(field.type)]
    }
    result.sort { left, right -> left.name <=> right.name }
    return result
}
/**
 * Lists the public methods of the generated class as name/signature maps, sorted by name
 * and then by signature, as input for calculateSerialId.
 *
 * Covers toString() plus one getter and one setter per field. Setters return the
 * generated class itself, so their descriptor ends in the class's qualified name.
 *
 * @param fields    field descriptors (name, type)
 * @param className simple name of the generated class
 * @return sorted list of maps with keys name and signature
 */
def getMethodSignatures(fields, className) {
    // Drop the trailing ';' kept on packageName; an empty package means the default package.
    String pkg = packageName.endsWith(";") ? packageName.substring(0, packageName.length() - 1) : packageName
    String qualifiedName = pkg ? pkg + "." + className : className
    def sigs = []
    sigs.add([name: "toString", signature: "()" + getClassSignature("String")])
    fields.each {
        sigs.add([name: "get" + it.name.capitalize(), signature: "()" + getClassSignature(it.type)])
    }
    fields.each {
        sigs.add([name     : "set" + it.name.capitalize(),
                  signature: "(" + getClassSignature(it.type) + ")L" + qualifiedName + ";"])
    }
    sigs.sort { a, b -> (a.name <=> b.name) ?: (a.signature <=> b.signature) }
    return sigs
}
/**
 * Returns the JVM descriptor (e.g. "Ljava/lang/String;") for a Java type name used in the
 * generated source.
 *
 * Known names are resolved through fillTypeMapping. Any other name is treated as fully
 * qualified when it contains a '.'; otherwise it is treated as a class in the generated package.
 *
 * @param type Java type name as written in the generated source
 * @return descriptor of the form "L<internal/name>;"
 */
String getClassSignature(String type) {
    String internalName = fillTypeMapping.get(type)
    if (internalName == null) {
        if (type.contains(".")) {
            internalName = type
        } else {
            String pkg = packageName.endsWith(";") ? packageName.substring(0, packageName.length() - 1) : packageName
            internalName = pkg ? pkg + "." + type : type
        }
    }
    return "L" + internalName.replace(".", "/") + ";"
}
|