基于 Spring Boot 实现一个 RPC 框架
需要用到的技术点:
自定义springBootStarter。
自定义starter读取配置。
自定义starter 向引入依赖的项目中注册bean。
基于FactoryBean 实现的JDK动态代理。
基于Netty实现的远程通信。
基于spring原理实现自定义的bean注入,以及bean的属性注入。
目标
通过自定义的 RPC 框架,实现类似于 Dubbo 的远程方法调用。
1、需要有一个自定义的注册中心 rpc-server
2、需要一个自定义的springboot-starter
注册中心
注册功能
/**
 * HTTP endpoint through which a provider registers the classes it exposes.
 * <p>
 * Logs the incoming payload and delegates to {@code registerService.register(entity)}.
 *
 * @param entity registration payload read from the request body
 */
@PostMapping("/liz/rpc/register")
public void providerRegister(@RequestBody RegisterEntity entity) {
    log.info("接口访问:/liz/rpc/register,data:{}", entity);
    registerService.register(entity);
}
/**
 * Records (or refreshes) the node that a provider client exposes for each of its classes.
 * <p>
 * For every class name in {@code entity.getClassMethodNameMap()}, the node keyed by the client id
 * under {@code GlobalHolder.classNodeMap} is created if absent and then updated with the client's
 * ip, method list and protocol port.
 *
 * @param entity registration payload; must carry a non-empty ip and client id
 * @throws IllegalArgumentException if the entity, its ip or its client id is missing
 */
public void register(RegisterEntity entity) {
    if (entity == null) {
        throw new IllegalArgumentException("参数不全");
    }
    String clientId = entity.getClientId();
    String ip = entity.getIp();
    if (StringUtils.isEmpty(ip) || StringUtils.isEmpty(clientId)) {
        throw new IllegalArgumentException("参数不全");
    }
    Map<String, List<String>> classMethodNameMap = entity.getClassMethodNameMap();
    if (classMethodNameMap == null || classMethodNameMap.isEmpty()) {
        // A provider that exposes no classes has nothing to register; avoid an NPE on the loop below.
        return;
    }
    for (Map.Entry<String, List<String>> entry : classMethodNameMap.entrySet()) {
        String className = entry.getKey();
        Map<String, Node> nodeMap = GlobalHolder.classNodeMap.computeIfAbsent(className, k -> Maps.newConcurrentMap());
        Node node = nodeMap.computeIfAbsent(clientId, key -> new Node());
        node.setIp(ip);
        node.setMethodList(entry.getValue());
        node.setProtocolPort(entity.getProtocolPort());
    }
}
发现功能
/**
 * HTTP endpoint used for discovery: exposes every registered class and its provider nodes.
 *
 * @return the registry map, keyed by class name, then by client id
 */
@GetMapping("/liz/rpc/remote-class-node")
public Map<String, Map<String, Node>> remoteClassNode() {
    log.info("接口访问:/liz/rpc/remote-class-node");
    Map<String, Map<String, Node>> registry = GlobalHolder.classNodeMap;
    return registry;
}
自定义starter
注解
/**
 * Marks a member that should receive a client-side RPC proxy.
 * <p>
 * The annotation carries no attributes. Only annotated fields are acted upon by
 * {@code CustomerResourceAnnotationBeanPostProcessor} in this project; the METHOD and
 * ANNOTATION_TYPE targets are permitted but not processed there.
 */
@Target({ElementType.FIELD, ElementType.METHOD, ElementType.ANNOTATION_TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface CustomerResource {
}
/**
 * Marks an implementation class as a remotely callable RPC provider.
 * <p>
 * Classes carrying this annotation under the configured base package are registered as Spring
 * beans by {@code ProviderServiceAnnotationBeanPostProcessor}. Because the annotation is
 * {@code @Inherited}, subclasses of an annotated class are treated as annotated too.
 */
@Inherited
@Target(ElementType.TYPE)
@Documented
@Retention(RetentionPolicy.RUNTIME)
public @interface ProviderService {
}
/**
 * Enables the Liz RPC framework on the annotated configuration class.
 * <p>
 * Importing {@link LizRpcComponentScanRegistrar} lets Spring register the framework's
 * bean-definition and bean post-processors for this application.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
@Import(LizRpcComponentScanRegistrar.class)
public @interface LizRpcComponentScan {

    /** Dot-separated root package scanned for {@code @ProviderService} classes; empty by default. */
    String basePackage() default "";
}
动态代理功能
需要将spring bean中带有 CustomerResource 注解的属性,创建一个动态代理对象 设置到bean属性上
ResourceBean
/**
 * {@link FactoryBean} that produces a JDK dynamic proxy for a remote interface.
 * <p>
 * Every call to {@link #getObject()} builds a fresh proxy whose invocations are handled by a
 * {@code ServiceProxy} for the same interface.
 *
 * @param <T> the proxied interface type
 */
public class ResourceBean<T> implements FactoryBean<T> {

    /** Interface implemented by the generated proxy; {@link Proxy} requires it to be an interface. */
    private final Class<T> interfaceType;

    public ResourceBean(Class<T> interfaceType) {
        this.interfaceType = interfaceType;
    }

    @Override
    public T getObject() throws Exception {
        InvocationHandler remoteHandler = new ServiceProxy<>(interfaceType);
        ClassLoader loader = interfaceType.getClassLoader();
        Class<?>[] proxiedInterfaces = {interfaceType};
        Object proxy = Proxy.newProxyInstance(loader, proxiedInterfaces, remoteHandler);
        // The proxy always implements interfaceType, so this checked cast cannot fail.
        return interfaceType.cast(proxy);
    }

    @Override
    public Class<?> getObjectType() {
        return interfaceType;
    }
}
ServiceProxy
/**
 * Invocation handler behind every client-side RPC proxy.
 * <p>
 * Interface methods are forwarded to the {@code LizAbstractProxy} selected by the configured proxy
 * type. The {@code java.lang.Object} methods that a JDK proxy dispatches here (equals, hashCode,
 * toString) are answered locally, based on the proxy's identity.
 *
 * @param <T> the proxied interface type
 */
@Slf4j
public class ServiceProxy<T> implements InvocationHandler {
    protected final Class<T> interfaceType;

    public ServiceProxy(Class<T> interfaceType) {
        this.interfaceType = interfaceType;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        if (Object.class.equals(method.getDeclaringClass())) {
            return invokeObjectMethod(proxy, method, args);
        }
        log.info("调用前,参数:{}", Arrays.toString(args));
        Object res = this.getResultFromRemote(interfaceType, method, args);
        log.info("调用后,结果:{}", res);
        return res;
    }

    /**
     * Answers equals/hashCode/toString for the proxy itself.
     * Forwarding them to {@code this} (the handler) would make {@code proxy.equals(proxy)} false,
     * breaking reflexivity and lookups of proxies in hash-based collections.
     */
    private Object invokeObjectMethod(Object proxy, Method method, Object[] args) throws Throwable {
        switch (method.getName()) {
            case "equals":
                return args != null && args.length == 1 && proxy == args[0];
            case "hashCode":
                return System.identityHashCode(proxy);
            case "toString":
                return "ServiceProxy(" + interfaceType.getName() + ")@"
                        + Integer.toHexString(System.identityHashCode(proxy));
            default:
                // JDK proxies only dispatch the three methods above; kept as a safe fallback.
                return method.invoke(this, args);
        }
    }

    @SuppressWarnings("ALL")
    private Object getResultFromRemote(Class<T> interfaceType, Method method, Object[] args) {
        String proxyType = GlobalConst.getProxyType();
        LizAbstractProxy lizAbstractProxy = ProxyTypeEnum.getProxy(proxyType);
        return lizAbstractProxy.getResult(interfaceType, method, args);
    }
}
Proxy
/**
 * Netty-based {@code LizAbstractProxy}: picks a provider node for the interface from the locally
 * cached registry ({@code LizRpcConfig.classNodeMap}) and performs the remote call on it.
 */
public class LizNettyProxy extends NettyBase implements LizAbstractProxy {

    /**
     * Invokes {@code method} on a remote provider of {@code interfaceType}.
     *
     * @throws IllegalStateException if no provider node is known for the interface
     */
    @Override
    public <T> Object getResult(Class<T> interfaceType, Method method, Object[] args) {
        Map<String, Node> nodeMap = LizRpcConfig.classNodeMap.get(interfaceType.getName());
        Node node = this.getNodeBalanced(nodeMap);
        if (node == null) {
            // Fail with a clear message instead of an NPE deep inside the Netty call.
            throw new IllegalStateException("No provider node registered for " + interfaceType.getName());
        }
        return super.getResultFromRemoteNode(interfaceType, method, args, node);
    }

    /**
     * Chooses the node to call. Currently this is simply the first node in iteration order.
     *
     * @param nodeMap nodes keyed by client id; may be null when nothing is registered
     * @return the chosen node, or null if {@code nodeMap} is null or empty
     */
    protected Node getNodeBalanced(Map<String, Node> nodeMap) {
        if (nodeMap == null) {
            return null;
        }
        return nodeMap.values().stream().findFirst().orElse(null);
    }
}
通过调用代理类(LizAbstractProxy 的实现)的 getResult 方法,获得远程调用的结果。
NettyBase
/**
 * Client-side Netty transport shared by Netty-based proxies.
 */
public class NettyBase extends BaseProxy {

    /**
     * Performs one blocking RPC round trip to {@code node} over a fresh Netty connection.
     * <p>
     * The call is encoded as an {@code InvokerMessage}. The method waits until the server closes
     * the channel and then returns whatever object was last read from it.
     *
     * @param interfaceType remote interface whose name identifies the target bean
     * @param method        interface method being invoked
     * @param args          call arguments as passed by the JDK proxy (null for zero-arg methods)
     * @param node          provider node to contact
     * @return the decoded response, or null if the server closed without replying
     * @throws IllegalArgumentException if {@code node} is null
     * @throws IllegalStateException    if connecting, sending or waiting for the reply fails
     */
    protected Object getResultFromRemoteNode(Class interfaceType, Method method, Object[] args, Node node) {
        if (node == null) {
            throw new IllegalArgumentException("no provider node available for " + interfaceType.getName());
        }
        String ip = node.getIp();
        String protocolPort = node.getProtocolPort();
        String target = interfaceType.getName() + "#" + method.getName() + " @ " + ip + ":" + protocolPort;

        InvokerMessage message = new InvokerMessage();
        message.setClassName(interfaceType.getName());
        message.setMethodName(method.getName());
        message.setLizRpcArgs(super.generateArgList(args));
        // Send the declared signature instead of the arguments' runtime classes: runtime classes are
        // wrong for primitive parameters (int vs Integer) and supertype parameters (List vs ArrayList),
        // and deriving them required a needless JSON round trip of every argument.
        message.setParameterTypes(method.getParameterTypes());

        final RpcProxyHandler customerHandler = new RpcProxyHandler();
        EventLoopGroup group = new NioEventLoopGroup();
        try {
            Bootstrap bootstrap = new Bootstrap();
            bootstrap.group(group)
                    .channel(NioSocketChannel.class)
                    .option(ChannelOption.TCP_NODELAY, true)
                    .handler(new ChannelInitializer<Channel>() {
                        @Override
                        protected void initChannel(Channel channel) throws Exception {
                            ChannelPipeline pipeline = channel.pipeline();
                            // 4-byte length prefix framing; must mirror the server pipeline.
                            pipeline.addLast("frameDecoder", new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4));
                            pipeline.addLast("frameEncoder", new LengthFieldPrepender(4));
                            pipeline.addLast("encoder", new ObjectEncoder());
                            pipeline.addLast(new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.cacheDisabled(null)));
                            pipeline.addLast("myHandler", customerHandler);
                        }
                    });
            ChannelFuture future = bootstrap.connect(ip, Integer.parseInt(protocolPort)).sync();
            future.channel().writeAndFlush(message).sync();
            // The server closes the connection after replying; that close marks the end of the call.
            future.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("RPC call interrupted: " + target, e);
        } catch (Exception e) {
            // Previously swallowed (printed and returned null), which hid connection failures.
            throw new IllegalStateException("RPC call failed: " + target, e);
        } finally {
            group.shutdownGracefully();
        }
        return customerHandler.getResponse();
    }
}
RpcProxyHandler
/**
 * Client-side inbound handler that captures the single response of one RPC call.
 * <p>
 * A new instance is used per call. Its fields are written on the Netty event-loop thread and read
 * by the calling thread afterwards, so they are volatile.
 */
public class RpcProxyHandler extends ChannelInboundHandlerAdapter {
    private volatile Object result;
    private volatile Throwable failure;

    /** Returns the last object read from the channel, or null if nothing was received. */
    public Object getResponse() {
        return result;
    }

    /** Returns the exception reported on the channel, or null if none occurred. */
    public Throwable getFailure() {
        return failure;
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        result = msg;
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        failure = cause;
        System.out.println("client exception is general");
        // Close the channel so a caller blocked on closeFuture() is released instead of hanging forever.
        ctx.close();
    }
}
BaseProxy
/**
 * Shared helpers for client-side proxies.
 */
public abstract class BaseProxy {

    /**
     * Converts call arguments into their wire form: each argument's runtime class name plus its JSON.
     *
     * @param args arguments as delivered by a JDK proxy; {@code null} means a zero-argument method
     * @return one {@code LizRpcArg} per argument, in order (empty for a zero-argument call)
     * @throws IllegalArgumentException if any argument is null, because its class cannot be recorded
     */
    protected List<LizRpcArg> generateArgList(Object[] args) {
        if (args == null) {
            // InvocationHandler.invoke receives null (not an empty array) for no-arg methods.
            return new ArrayList<>();
        }
        List<LizRpcArg> res = new ArrayList<>(args.length);
        for (int i = 0; i < args.length; i++) {
            Object arg = args[i];
            if (arg == null) {
                throw new IllegalArgumentException("null RPC argument at index " + i + " is not supported");
            }
            LizRpcArg lizRpcArg = new LizRpcArg();
            lizRpcArg.setClassName(arg.getClass().getName());
            lizRpcArg.setObjStr(JSON.toJSONString(arg));
            res.add(lizRpcArg);
        }
        return res;
    }
}
提供远程调用
/**
 * Provider side of the Netty transport: starts (at most once per JVM) a Netty server that dispatches
 * incoming {@code InvokerMessage}s to {@code MyReflectInvokingHandler}.
 */
public class LizRpcNettyProvider extends AbstractLizRpc implements LizPrcProvider {
    /** Ensures only one server is started; compareAndSet makes the check-and-set atomic. */
    private static final java.util.concurrent.atomic.AtomicBoolean nettyServerStartFlag =
            new java.util.concurrent.atomic.AtomicBoolean(false);
    private static final Logger log = LoggerFactory.getLogger(LizRpcNettyProvider.class);

    /**
     * Starts the Netty server asynchronously on the default executor the first time it is called.
     * Later calls are no-ops.
     */
    @Override
    public void providerToRemote(List<String> classNameList, LizRpcProperties lizRpcProperties) {
        if (nettyServerStartFlag.compareAndSet(false, true)) {
            ExecutorServiceComponent.getDefaultExecutorService().execute(() -> providerToRemote(lizRpcProperties));
        }
    }

    /**
     * Binds the server on {@code lizRpcProperties.getProtocolPort()} and blocks until it is closed.
     */
    public static void providerToRemote(LizRpcProperties lizRpcProperties) {
        EventLoopGroup master = new NioEventLoopGroup();
        EventLoopGroup slave = new NioEventLoopGroup();
        try {
            ServerBootstrap serverBootstrap = new ServerBootstrap();
            serverBootstrap.group(master, slave)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel socketChannel) throws Exception {
                            ChannelPipeline pipeline = socketChannel.pipeline();
                            pipeline.addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4));
                            pipeline.addLast(new LengthFieldPrepender(4));
                            pipeline.addLast("encoder", new ObjectEncoder());
                            // NOTE(security): ObjectDecoder performs Java deserialization of whatever a
                            // remote peer sends, with no class filter. Do not expose this port to untrusted
                            // networks; consider a JSON/protobuf codec or a restricted ClassResolver.
                            pipeline.addLast("decoder", new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.cacheDisabled(null)));
                            pipeline.addLast(new MyReflectInvokingHandler());
                        }
                    });
            ChannelFuture future = serverBootstrap.bind(Integer.parseInt(lizRpcProperties.getProtocolPort())).sync();
            log.info("RPC registry start listen at :{}", lizRpcProperties.getProtocolPort());
            future.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.warn("RPC provider server interrupted", e);
        } catch (Exception e) {
            log.error("RPC provider server failed", e);
        } finally {
            master.shutdownGracefully();
            slave.shutdownGracefully();
        }
    }
}
public class MyReflectInvokingHandler extends ChannelInboundHandlerAdapter {
private static final Logger log = LoggerFactory.getLogger(MyReflectInvokingHandler.class);
@Override
public void channelRead(ChannelHandlerContext ctx, Object obj) throws Exception {
Object result = new Object();
InvokerMessage message = (InvokerMessage) obj;
String className = message.getClassName();
LizRpcConfig lizRpcConfig = SpringContextUtils.getApplicationContext().getBean(LizRpcConfig.class);
if (lizRpcConfig.classNameList.contains(className)){
return;
}
List<Object> args = this.getArgs(message.getLizRpcArgs());
Object[] params = args.toArray();
Class<?>[] paramsType = this.getParamsType(params);
Object clazz = SpringContextUtils.getApplicationContext().getBean(message.getClassName());
Method method = clazz.getClass().getMethod(message.getMethodName(), paramsType);
result = method.invoke(clazz, params);
log.info("message:{},result:{}",message,result);
ctx.writeAndFlush(result);
ctx.close();
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
cause.printStackTrace();
ctx.close();
}
private List<Object> getArgs(List<LizRpcArg> argsList) {
try {
List<Object> res = new ArrayList<>();
for (LizRpcArg arg : argsList) {
Class<?> clazz = Class.forName(arg.getClassName());
Object o = JSON.parseObject(arg.getObjStr(), clazz);
res.add(o);
}
return res;
} catch (Exception e) {
throw new RuntimeException("参数转化异常:" + argsList);
}
}
private Class<? extends Object>[] getParamsType(Object[] params){
Class<? extends Object>[] paramClass = null;
if (params != null) {
int paramsLength = params.length;
paramClass = new Class[paramsLength];
for (int i = 0; i < paramsLength; i++) {
paramClass[i] = params[i].getClass();
}
}
return paramClass;
}
}
核心逻辑
自定义配置
/**
 * Framework settings bound from the {@code liz.rpc.*} configuration namespace.
 * Lombok generates the accessors; Spring registers the class as a component.
 */
@Component
@ConfigurationProperties("liz.rpc")
@Data
public class LizRpcProperties {
    /** Host of the registry, scheme included (e.g. {@code http://localhost}) -- presumably rpc-server; confirm. */
    private String host;
    /** Port paired with {@link #host} (e.g. {@code 9999}) -- presumably the registry's HTTP port; confirm. */
    private String port;
    /** Dot-separated root package scanned for {@code @ProviderService} classes. */
    private String basePackage;
    /** Logical name of this application (e.g. {@code provider} / {@code customer}). */
    private String applicationName;
    /** Port the provider-side Netty server binds to; parsed as an int. */
    private String protocolPort;
    /** Transport selector (e.g. {@code netty}) -- presumably consumed via GlobalConst.getProxyType(); confirm. */
    private String proxyType;
}
在自定义的 Configuration 中:
1、扫描 LizRpcProperties.basePackage 下所有class。
2、找出带有 @ProviderService 注解的 class。
3、注册:将带有 @ProviderService 注解的 class 发送到 rpc-server。
4、发现:在rpc-server中查找注册的类
5、将一个监听远程调用的 LizPrcProvider 注册到spring上下文中
6、提供远程调用,异步执行 LizPrcProvider 的 providerToRemote方法,监听远程调用
ProviderService注册成spring的bean
通过 LizRpcComponentScan注解,让spring找到 自定义的bean注册器 LizRpcComponentScanRegistrar
/**
 * Registrar imported by {@code @LizRpcComponentScan}. It registers the provider-side
 * {@code ProviderServiceAnnotationBeanPostProcessor} for the configured base package, and the
 * consumer-side {@code CustomerResourceAnnotationBeanPostProcessor}.
 */
public class LizRpcComponentScanRegistrar implements ImportBeanDefinitionRegistrar {

    @Override
    public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
        AnnotationAttributes mapperScanAttrs = AnnotationAttributes
                .fromMap(importingClassMetadata.getAnnotationAttributes(LizRpcComponentScan.class.getName()));
        if (mapperScanAttrs == null) {
            // The importing class is not annotated with @LizRpcComponentScan: nothing to configure.
            return;
        }
        String basePackage = mapperScanAttrs.getString("basePackage");
        // A blank basePackage (the annotation default) gives the provider scanner nothing valid to
        // scan, so register it only when a package is actually configured.
        if (basePackage != null && !basePackage.trim().isEmpty()) {
            registerServiceAnnotationBeanPostProcessor(basePackage.trim(), registry);
        }
        registerReferenceAnnotationBeanPostProcessor(registry);
    }

    /** Registers the provider scanner under a generated name, with the package as constructor argument. */
    private void registerServiceAnnotationBeanPostProcessor(String basePackages, BeanDefinitionRegistry registry) {
        BeanDefinitionBuilder builder = rootBeanDefinition(ProviderServiceAnnotationBeanPostProcessor.class);
        builder.addConstructorArgValue(basePackages);
        builder.setRole(BeanDefinition.ROLE_INFRASTRUCTURE);
        AbstractBeanDefinition beanDefinition = builder.getBeanDefinition();
        BeanDefinitionReaderUtils.registerWithGeneratedName(beanDefinition, registry);
    }

    /** Registers the {@code @CustomerResource} injector once under its fixed bean name. */
    private void registerReferenceAnnotationBeanPostProcessor(BeanDefinitionRegistry registry) {
        if (registry.containsBeanDefinition(CustomerResourceAnnotationBeanPostProcessor.BEAN_NAME)) {
            // Several @LizRpcComponentScan imports must not re-register (or override) the same name.
            return;
        }
        RootBeanDefinition beanDefinition = new RootBeanDefinition(CustomerResourceAnnotationBeanPostProcessor.class);
        beanDefinition.setRole(BeanDefinition.ROLE_INFRASTRUCTURE);
        registry.registerBeanDefinition(CustomerResourceAnnotationBeanPostProcessor.BEAN_NAME, beanDefinition);
    }
}
ProviderServiceAnnotationBeanPostProcessor
/**
 * Scans a package directory on the classpath and registers every {@code @ProviderService} class as a
 * Spring bean. Each bean is named after the class's first implemented interface, falling back to
 * the class name.
 * <p>
 * Only exploded class directories are supported; packages inside jar files are skipped with a warning.
 */
public class ProviderServiceAnnotationBeanPostProcessor implements BeanDefinitionRegistryPostProcessor {
    private static final String CLASS_FILE_SUFFIX = ".class";

    private final Logger logger = LoggerFactory.getLogger(getClass());

    /** Dot-separated root package to scan. */
    private String packagesToScan;

    public String getPackagesToScan() {
        return packagesToScan;
    }

    public void setPackagesToScan(String packagesToScan) {
        this.packagesToScan = packagesToScan;
    }

    public ProviderServiceAnnotationBeanPostProcessor(String packagesToScan) {
        this.packagesToScan = packagesToScan;
    }

    @Override
    public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
        registerServiceBeans(packagesToScan, registry);
    }

    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory configurableListableBeanFactory) throws BeansException {
        // Nothing to do: all work happens at bean-definition registration time.
    }

    /** Scans, filters to @ProviderService classes, and registers one bean definition per class. */
    private void registerServiceBeans(String packagesToScan, BeanDefinitionRegistry registry) {
        List<String> classNameList = new ArrayList<>();
        this.doScanPackage(packagesToScan, classNameList);
        List<String> filterList = new ArrayList<>();
        this.getAnnotationProviderServiceClass(classNameList, filterList);
        for (String className : filterList) {
            try {
                Class<?> clazz = Class.forName(className);
                Class<?>[] interfaces = clazz.getInterfaces();
                // Name the bean after the interface because the server resolves the target bean by the
                // interface name that the client puts in the message.
                String key = interfaces.length > 0 ? interfaces[0].getName() : className;
                BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder.genericBeanDefinition(clazz);
                registry.registerBeanDefinition(key, beanDefinitionBuilder.getRawBeanDefinition());
            } catch (Exception | LinkageError e) {
                logger.warn("注入bean失败,className:{}", className, e);
            }
        }
    }

    /** Resolves the package to a directory on the classpath and collects the class names inside it. */
    private void doScanPackage(String basePackages, List<String> classNameList) {
        if (basePackages == null) {
            logger.warn("扫描包为空");
            return;
        }
        String scanPath = basePackages.replace('.', '/');
        URL url = this.getClass().getClassLoader().getResource(scanPath);
        if (url == null) {
            // An `assert url != null` here was disabled at runtime and led to an NPE.
            logger.warn("扫描包不存在:{}", basePackages);
            return;
        }
        if (!"file".equals(url.getProtocol())) {
            logger.warn("暂不支持扫描非文件目录中的包:{}", url);
            return;
        }
        File root;
        try {
            // toURI() decodes escapes such as %20 that URL.getFile() leaves in the path.
            root = new File(url.toURI());
        } catch (java.net.URISyntaxException | IllegalArgumentException e) {
            root = new File(url.getFile());
        }
        this.collectClassNames(root, basePackages, classNameList);
    }

    /** Recursively adds the fully qualified name of every .class file under dir to classNameList. */
    private void collectClassNames(File dir, String packageName, List<String> classNameList) {
        File[] children = dir.listFiles();
        if (children == null) {
            return;
        }
        String prefix = packageName.isEmpty() ? "" : packageName + ".";
        for (File child : children) {
            String fileName = child.getName();
            if (child.isDirectory()) {
                collectClassNames(child, prefix + fileName, classNameList);
            } else if (fileName.endsWith(CLASS_FILE_SUFFIX)) {
                classNameList.add(prefix + fileName.substring(0, fileName.length() - CLASS_FILE_SUFFIX.length()));
            }
        }
    }

    /**
     * Moves classes annotated with {@code @ProviderService} into {@code filterList} and removes the
     * other classes from {@code classNameList}. Classes that fail to load are logged and skipped.
     */
    private void getAnnotationProviderServiceClass(List<String> classNameList, List<String> filterList) {
        if (CollectionUtils.isEmpty(classNameList)) {
            return;
        }
        ClassLoader classLoader = this.getClass().getClassLoader();
        Iterator<String> iterator = classNameList.iterator();
        while (iterator.hasNext()) {
            // Read the name exactly once per iteration so the catch block logs the right class
            // and never advances the iterator past an element.
            String className = iterator.next();
            try {
                // initialize=false: inspect annotations without running every scanned class's static initializer.
                Class<?> clazz = Class.forName(className, false, classLoader);
                if (clazz.isAnnotationPresent(ProviderService.class)) {
                    filterList.add(className);
                } else {
                    iterator.remove();
                }
            } catch (Exception | LinkageError e) {
                logger.warn("扫描class异常:{}", className, e);
            }
        }
    }
}
CustomerResource 注释处理
CustomerResource 修饰属性,创建动态代理设置到spring的Bean上
通过上面的动态代理功能实现。
/**
 * Injects an RPC client proxy into every field annotated with {@code @CustomerResource}. Fields are
 * collected from the bean's class and all of its superclasses.
 * <p>
 * Injection failures are logged and the field is left unchanged (typically null).
 */
public class CustomerResourceAnnotationBeanPostProcessor implements BeanPostProcessor {
    private final Logger logger = LoggerFactory.getLogger(getClass());
    public static final String BEAN_NAME = "referenceAnnotationBeanPostProcessor";

    @Override
    public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
        // Walk the hierarchy: getDeclaredFields() alone misses fields declared on superclasses,
        // including the user class behind a generated subclass such as a CGLIB-enhanced one.
        for (Class<?> type = bean.getClass(); type != null && type != Object.class; type = type.getSuperclass()) {
            for (Field field : type.getDeclaredFields()) {
                if (field.isAnnotationPresent(CustomerResource.class)) {
                    injectRemoteProxy(bean, beanName, field);
                }
            }
        }
        return bean;
    }

    /** Creates the proxy for the field's declared type and assigns it; failures are logged. */
    private void injectRemoteProxy(Object bean, String beanName, Field field) {
        ResourceBean<?> lizRpcProxy = this.getLizRpcProxy(field.getType());
        if (lizRpcProxy == null) {
            return;
        }
        try {
            field.setAccessible(true);
            field.set(bean, lizRpcProxy.getObject());
        } catch (Throwable t) {
            logger.warn("注入失败-bean:{},field:{}", beanName, field.getName(), t);
        }
    }

    /**
     * Builds a proxy factory for {@code interfaceType}. The field's declared type is used directly;
     * re-resolving it via a name and Class.forName is unnecessary and can pick a different loader.
     *
     * @return the factory, or null if the type is not an interface (JDK proxies require one)
     */
    private ResourceBean<?> getLizRpcProxy(Class<?> interfaceType) {
        if (!interfaceType.isInterface()) {
            logger.warn("create proxy Object failed,error:{}", interfaceType.getName() + " is not an interface");
            return null;
        }
        return new ResourceBean<>(interfaceType);
    }
}
使用方法:
1、启动 rpc-server 注册中心。
2、进入./core/liz-rpc-srping-boot-starter 目录 执行
mvn clean
mvn install -DskipTests
3、创建一个 demo-api 项目 用于声明一个方法的接口
public interface MyApi {
Integer add(Integer a, Integer b);
}
创建一个 demo-provider 项目 作为方法提供者
/**
 * Provider-side implementation of {@link MyApi}; the {@code @ProviderService} marker makes the
 * framework's scanner register it as a bean.
 */
@ProviderService
public class MyApiImpl implements MyApi {

    /** Returns {@code a + b}; null arguments fail on unboxing. */
    @Override
    public Integer add(Integer a, Integer b) {
        int sum = Integer.sum(a, b);
        return sum;
    }
}
@LizRpcComponentScan(basePackage = "com.liz.demoprovider")
liz:
rpc:
host: http://localhost
port: 9999
base-package: com.liz.demoprovider
protocol-port: 20882
proxy-type: netty
application-name: provider
redis:
password: 123456
host: 127.0.0.1
创建一个 demo-customer 作为方法调用者
/**
 * Demo consumer endpoint that calls {@link MyApi#add} on a remote provider through an RPC proxy.
 */
@RestController
public class CusController {

    /** Remote proxy injected by CustomerResourceAnnotationBeanPostProcessor; null if injection failed. */
    @CustomerResource
    private MyApi myApi;

    /**
     * Adds two numbers remotely.
     *
     * @throws IllegalStateException if the remote proxy was not injected
     */
    @GetMapping(value = "/cus/add")
    public Integer add(@RequestParam("a") Integer a, @RequestParam("b") Integer b) {
        // Replaces a leftover debug System.out.println; fail with a clear message instead of an NPE.
        if (myApi == null) {
            throw new IllegalStateException("MyApi remote proxy was not injected; check the @CustomerResource setup");
        }
        return myApi.add(a, b);
    }
}
@LizRpcComponentScan(basePackage = "com.liz.democustomer")
liz:
rpc:
host: http://localhost
port: 9999
base-package: com.liz.democustomer
protocol-port: 20883
proxy-type: netty
application-name: customer
redis:
password: 123456
host: 127.0.0.1
server:
port: 8082
调用 demo-customer 项目的 /cus/add 接口,例如访问 http://localhost:8082/cus/add?a=1&b=2(端口即上面 server.port 配置的 8082)。
都看到这了,点个赞吧
代码发布到gitee上,地址:https://gitee.com/lgw996699/rpc-frame.git
|