if-else泛滥成灾?试试【策略模式】优雅的给去掉吧!

  |   0 评论   |   浏览

hi,大家好,我是mbb

前段时间,和大家分享了一个关于如何优雅使用if-else的文章,之后陆陆续续好几个小伙伴微信给我留言聊最后那一段,说没有看明白,那么今天就来针对性的整理一下;答应粉丝的事情,必须得完成的。

示例源码地址: https://gitee.com/pengfeilu/strategy-demo

首先来回顾一下前篇文章的那段问题,以下是原文;到底该如何优雅的替代if-else?本文的目的也就是通过详细的示例,把这个细节给说清楚:

扩展应用程序,完全避免使用 If-Else

这是一个稍微高级的示例。通过用对象替换它们,知道何时甚至完全消除 If。

通常,您会发现自己不得不扩展应用程序的某些部分。作为初级开发人员,您可能会倾向于通过添加额外的 If-Else(即 else-if)语句来做到这一点。

举这个说明性的例子。在这里,我们需要将 Order 实例显示为字符串。首先,我们只有两种字符串表示形式:JSON 和纯文本。

在此阶段使用 If-Else 并不是什么大问题,如果我们可以轻松替换其他,只要如前所述即可。

知道我们需要扩展应用程序的这一部分,这种方法绝对是不可接受的。

上面的代码不仅违反了"打开/关闭"原则,而且阅读得不好,还会引起可维护性方面的麻烦。

正确的方法是遵循 SOLID 原则的方法,我们通过实施动态类型发现过程(在本例中为策略模式)来做到这一点。

重构这个混乱的过程的过程如下:

  • 使用公共接口将每个分支提取到单独的策略类中。
  • 动态查找实现通用接口的所有类。
  • 根据输入决定执行哪种策略。

替换上面示例的代码如下所示。是的,这是更多代码的方式。它要求您了解类型发现的工作原理。但是动态扩展应用程序是一个高级主题。

我只显示将替换 If-Else 示例的确切部分。如果要查看所有涉及的对象,请查看此要点。

image-20210519092009272

让我们快速浏览一下代码。方法签名保持不变,因为调用者不需要了解我们的重构。

首先,获取实现通用接口 IOrderOutputStrategy 的程序集中的所有类型。然后,我们建立一个字典,格式化程序的 displayName 的名称为 key,类型为 value。

然后从字典中选择格式化程序类型,然后尝试实例化策略对象。最后,调用策略对象的 ConvertOrderToString。

这是一篇译文,如果单看这一段描述和示例,确实有一点点懵!

其实文章已经把想表达的意思表达出来了,只是没有表达的特别清晰、详细,所以导致基础不是特别好的同学看起来就有那么点点吃力;但核心的意思已经总结在最后那一段,采用策略模式,将if-else给替换掉。

该如何替换呢?示例说明的并不是特别的清晰;下面就以一个和粉丝聊的示例,来详细的讲解一下这个过程。

策略模式

在讲解示例之前,既然知道要讲策略模式,就得了解一下他?

什么是策略模式?

官话:策略(Strategy)模式是定义了一系列算法,并将每个算法封装起来,使它们可以相互替换,且算法的变化不会影响使用算法的客户。策略模式属于对象行为模式,它通过对算法进行封装,把使用算法的责任和算法的实现分割开来,并委派给不同的对象对这些算法进行管理。

策略模式的结构图

模式的优缺点

  • 优点

    1. 多重条件语句不易维护,而使用策略模式可以避免使用多重条件语句,如 if...else 语句、switch...case 语句。
    2. 策略模式提供了一系列的可供重用的算法族,恰当使用继承可以把算法族的公共代码转移到父类里面,从而避免重复的代码。
    3. 策略模式可以提供相同行为的不同实现,客户可以根据不同时间或空间要求选择不同的。
    4. 策略模式提供了对开闭原则的完美支持,可以在不修改原代码的情况下,灵活增加新算法。
    5. 策略模式把算法的使用放到环境类中,而算法的实现移到具体策略类中,实现了二者的分离。
  • 缺点

    1. 客户端必须理解所有策略算法的区别,以便适时选择恰当的算法类。
    2. 策略模式造成很多的策略类,增加维护难度。

场景及基础准备

理论的东西了解了之后,当然得实操一遍,开发过程中,到底如何通过策略将if-else给去掉呢?

示例场景

VIP是很多系统都有的一种奖励机制,用户不同的VIP等级,有着不同的特权或福利;这里就以电商VIP折扣的这么一个场景来作为示例进行讲解;如果让你去开发这个功能模块的话,就必定会遇到下面这个问题?

问题:如何通过用户的vip等级,对商品的价格给予一定的折扣?

基础代码

  • 定义一个VIP的接口
    定义一个用来获取折扣的价格的方法,具体如何折扣,由各个VIP等级的实例自己去实现;

    @Service
    public interface VipService {
    
        /**
         * 获取折扣价
         *
         * @param prive 商品加个 单位:分
         * @return
         */
        Integer getPrice(Integer prive);
    }
    
  • 各个VIP等级的实现

    // 普通用户
    @Service("vip0")
    public class Vip0ServiceImpl implements VipService {
        @Override
        public Integer getPrice(Integer prive) {
            // 普通用户 原价
            return prive;
        }
    }
    
    // VIP1
    @Service("vip1")
    public class Vip1ServiceImpl implements VipService {
        @Override
        public Integer getPrice(Integer prive) {
            // 优惠1块
            return prive - 100;
        }
    }
    
    
    // VIP2
    @Service("vip2")
    public class Vip2ServiceImpl implements VipService {
        @Override
        public Integer getPrice(Integer prive) {
            // 优惠2块
            return prive - 200;
        }
    }
    
    // VIP3
    @Service
    public class Vip3ServiceImpl implements VipService {
        @Override
        public Integer getPrice(Integer prive) {
            return prive - 400;
        }
    }
    
    //....
    

    有了基础的接口和具体的折扣实现之后,就得根据用户的会员等级去用调用不同的会员实现(策略),得到折后价格

有了对比之后才能更直观的看出好坏,所以这里把两种方式都写在这里,就能很明显的感觉到差异

基于if-else的获取折扣价

这是一种最容易想到的实现方案,如果是新手小伙伴,可能第一时间想到的也是它了;

示例中的@Service("vipx")是为了后续通过spring获取使用,这里用不上;

public Integer getPrice1(String vipLevel, Integer price) {

    VipService vipService = new Vip0ServiceImpl();
    if ("vip1".equals(vipLevel)) {
        vipService = new Vip1ServiceImpl();
    } else if ("vip2".equals(vipLevel)) {
        vipService = new Vip2ServiceImpl();
    } else if ("vip3".equals(vipLevel)) {
        vipService = new Vip3ServiceImpl();
    } else if ("vipX".equals(vipLevel)) {
        // vipService = .....
    }

    return vipService.getPrice(price);
}

是不是很熟悉?都写过类似的吧?

  • 测试

    @Test
    public void getPrice1() {
        Integer vip1 = priceController.getPrice1("vip1", 10000);
        log.info("vip1 price:{}", vip1);
    
        Integer vip2 = priceController.getPrice1("vip2", 10000);
        log.info("vip2 price:{}", vip2);
    }
    


    效果已经达到了。

  • 那他有那些优缺点?

    • 优点理解容易,实现简单;这种写法,我想每个人或多或少也都有写过!
    • 缺点
      不利于扩展
      新增算法(VIP等级)就会违背开闭原则
  • 新增等级
    当我需要新增一个vip等级的时候,必须在这里加上一个if-else的判断;否则就没办法使用这个折扣;

    else if ("vipX".equals(vipLevel)) {
        vipService = new VipXServiceImpl();
    }
    

基于spring的策略使用

实现类加上@Service注解

目的是让Spring扫描类,并实例化缓存起来,方便使用的时候,直接根据vip标识给取出来;

示例中的实现类都已经加上了;当Spring Bean的作用域(scope)是singleton时,容器启动会将对象实例化并缓存起来;缓存的容器你可以理解为Map;当 @Service(“vip0”)指定了value(vip0)时,就会用value作为key进行缓存;如果没有指定value,就通过类的名称进行缓存;

调用

  • 上下文对象

    @Autowired
    ApplicationContext applicationContext;
    
  • 根据不同的vip等级获取不同的策略
    根据vip的等级标识,在Spring容器的缓存中将对应的实现类取出来;根本不需要任何的if-else的方法;

    public Integer getPrice2(String vipLevel, Integer price) {
        VipService vipService = applicationContext.getBean(vipLevel, VipService.class);
        return vipService.getPrice(price);
    }
    

    核心代码

    applicationContext.getBean(name, VipService.class);
    

    第一个参数:name就是缓存时使用的key,一开头有说过大概的规则;
    第二个参数Class指定具体的接口;

    这样就可以根据vipLevel拿到具体的实现;具体getBean的细节,这里就不展开了,涉及到Spring源码部分,不是本文的重点;

  • 测试

    @Test
    public void getPrice2() {
        Integer vip1 = priceController.getPrice2("vip1", 10000);
        log.info("vip1 price:{}", vip1);
    
        Integer vip3 = priceController.getPrice2("vip3ServiceImpl", 10000);
        log.info("vip3 price:{}", vip3);
    }
    

  • 优缺点

    • 优点
      代码简洁,可扩展性强
    • 缺点
      理解难度增减;但是,核心代码,Spring给封装好了,直接使用就好;
  • 扩展
    如果那天,产品经理需要你加个VIP4,就只需要加一个VIP4的实现

    @Service("vip4")
    public class Vip4ServiceImpl implements VipService{
        @Override
        public Integer getPrice(Integer prive) {
            return prive-600;
        }
    }
    

    使用方就可以直接调用;

    Integer vip4 = priceController.getPrice2("vip4", 10000);
    log.info("vip4 price:{}", vip4);
    

通过这两种实现方式的比较,文章一开始的那个问题,应该就可以明白了吧!同时也实现了通过策略模式完美去掉if-else的目标;

虽然说,前面通过Spring,实现了我们想要的,但是核心的环境类(context)并不是我们写的,Spring都给封装好了;如果那天不用Spring框架,又该怎么办呢?

接下来,为了能够更加透彻的了解被封装起来的那段逻辑;我们就仿着Spring的流程,来自己写一个环境类,MyContext;

自己写的难点和思路

说明一下,这里是一个简单的实现方案,咱不可能上来就写个Spring,不现实;重点是思路。

自己实现的难点在哪里?

我觉得难点就是下面的这行环境类(context)代码,虽然只有一行代码,但是做了最核心的事情;

applicationContext.getBean(vipLevel, VipService.class);

难点分析:

  1. 如何找到所有实现了VipService接口的实现类?
  2. 如何通过vipLevel找到对应的实现类?
  3. 如何对扫描后的数据进行缓存?

解决难点的思路

根据上面的问题,就来理一下解决难点的思路:

  1. 将所有实现了VipService接口的类全部找出来;
  2. 通过一些方式将vipLevel和对应的实现类关联;

    Spring用了注解,那我们这里也就用注解去实现;但是不限于注解,也可以采用其他方式

  3. 将vipLevel和实现类给缓存起来,方便后续使用的时候

编码

自定义注解

  • 自定义一个类似于@Service的注解@MyService
    作用和@Service一样,用来标明Service的具体实现

    @Target({ElementType.TYPE})
    @Retention(RetentionPolicy.RUNTIME)
    @Documented
    public @interface MyService {
        String value() default "";
    }
    
  • 将自定义的直接添加到各个VIP实现类上

    @Service("vip0")
    @MyService("vip0")
    public class Vip0ServiceImpl implements VipService {
        @Override
        public Integer getPrice(Integer prive) {
            // 普通用户 原价
            return prive;
        }
    }
    
    // 
    @Service("vip1")
    @MyService("vip1")
    public class Vip1ServiceImpl implements VipService{
    
    }
    
    //.....
    

定义属于自己的环境类对象MyContext

该类的getBean方法主要就是做以下3件事情

public class MyContext {
    public static <T> T getBean(String label, Class<T> clz) {
        // 1. 扫描出所有clz接口的具体实现
        // 2. 找到label与具体实现之际的关联
        // 3. 缓存并返回具体的实现
    }
}

接下来就对这三件事情做详细解读

扫描所有的接口实现

根据指定的接口,去找到他对用的所有的实现类;

  • 扫描的工具类
    工具类的作用就是根据指定的包的路径,去扫描出包下面所有的class;明白作用就行,没必要仔细去看这段代码

    public class ClassUtils {
    
        /**
         * 从包package中获取所有的Class
         *
         * @param packageName
         * @return
         */
        public static List<Class<?>> getClasses(String packageName) {
            //第一个class类的集合
            List<Class<?>> classes = new ArrayList<Class<?>>();
    
            //是否循环迭代
            boolean recursive = true;
    
            //获取包的名字 并进行替换
            String packageDirName = packageName.replace('.', '/');
    
            //定义一个枚举的集合 并进行循环来处理这个目录下的things
            Enumeration<URL> dirs;
    
            try {
                dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
                //循环迭代下去
                while (dirs.hasMoreElements()) {
                    //获取下一个元素
                    URL url = dirs.nextElement();
                    //得到协议的名称
                    String protocol = url.getProtocol();
                    //如果是以文件的形式保存在服务器上
                    if ("file".equals(protocol)) {
                        //获取包的物理路径
                        String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                        //以文件的方式扫描整个包下的文件 并添加到集合中
                        findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
                    } else if ("jar".equals(protocol)) {
                        //如果是jar包文件
    
                        //定义一个JarFile
                        JarFile jar;
                        try {
                            //获取jar
                            jar = ((JarURLConnection) url.openConnection()).getJarFile();
                            //从此jar包 得到一个枚举类
                            Enumeration<JarEntry> entries = jar.entries();
                            //同样的进行循环迭代
                            while (entries.hasMoreElements()) {
                                //获取jar里的一个实体 可以是目录 和一些jar包里的其他文件 如META-INF等文件
                                JarEntry entry = entries.nextElement();
                                String name = entry.getName();
                                //如果是以/开头的
                                if (name.charAt(0) == '/') {
                                    //获取后面的字符串
                                    name = name.substring(1);
                                }
    
                                //如果前半部分和定义的包名相同
                                if (name.startsWith(packageDirName)) {
                                    int idx = name.lastIndexOf('/');
                                    //如果以"/"结尾 是一个包
                                    if (idx != -1) {
                                        //获取包名 把"/"替换成"."
                                        packageName = name.substring(0, idx).replace('/', '.');
                                    }
                                    //如果可以迭代下去 并且是一个包
                                    if ((idx != -1) || recursive) {
                                        //如果是一个.class文件 而且不是目录
                                        if (name.endsWith(".class") && !entry.isDirectory()) {
                                            //去掉后面的".class" 获取真正的类名
                                            String className = name.substring(packageName.length() + 1, name.length() - 6);
                                            try {
                                                //添加到classes
                                                classes.add(Class.forName(packageName + '.' + className));
                                            } catch (ClassNotFoundException e) {
                                                e.printStackTrace();
                                            }
                                        }
                                    }
                                }
                            }
                        } catch (IOException e) {
                            e.printStackTrace();
                        }
                    }
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
    
            return classes;
        }
    
    
        /**
         * 以文件的形式来获取包下的所有Class
         *
         * @param packageName
         * @param packagePath
         * @param recursive
         * @param classes
         */
    
        public static void findAndAddClassesInPackageByFile(String packageName, String packagePath, final boolean recursive, List<Class<?>> classes) {
            //获取此包的目录 建立一个File
            File dir = new File(packagePath);
            //如果不存在或者 也不是目录就直接返回
            if (!dir.exists() || !dir.isDirectory()) {
                return;
            }
            //如果存在 就获取包下的所有文件 包括目录
            File[] dirfiles = dir.listFiles(new FileFilter() {
                //自定义过滤规则 如果可以循环(包含子目录) 或则是以.class结尾的文件(编译好的java类文件)
                public boolean accept(File file) {
                    return (recursive && file.isDirectory()) || (file.getName().endsWith(".class"));
                }
            });
    
            //循环所有文件
            for (File file : dirfiles) {
                //如果是目录 则继续扫描
                if (file.isDirectory()) {
                    findAndAddClassesInPackageByFile(packageName + "." + file.getName(),
                            file.getAbsolutePath(),
                            recursive,
                            classes);
                } else {
                    //如果是java类文件 去掉后面的.class 只留下类名
                    String className = file.getName().substring(0, file.getName().length() - 6);
                    try {
                        //添加到集合中去
                        classes.add(Class.forName(packageName + '.' + className));
                    } catch (ClassNotFoundException e) {
                        e.printStackTrace();
                    }
                }
            }
        }
    }
    
  • 获取所有的class

    // 为了减少扫描次数,这里就只扫描class所处的包及子包的class;可以根据需要调整
    String clzPackage = clz.getPackage().getName();
    List<Class<?>> classes = ClassUtils.getClasses(clz.getPackage().getName());
    
  • 找到实现了clz接口的所有实现类

    for (Class<?> scanClass : classes) {
        /**
         * scanClass.isInterface() 用来判断类是否是一个接口
         * clz.isAssignableFrom(scanClass) 判断scanClass是否实现了clz这个接口
         * Modifier.isAbstract(scanClass.getModifiers()) 判断是个抽象类
         */
        if (clz.isAssignableFrom(scanClass) && scanClass.isInterface() && !Modifier.isAbstract(scanClass.getModifiers())) {
            // 如果当前scanClass实现了clz
            // 并且 scanClass不是接口
            // 并且 clz不是抽象类
            
            // 这样就是一个符合条件的class
        }
    }
    

根据注解建立label与实现的关联

  • 获取class所有注解
    // 获取所有的注解
    Annotation[] annotations = scanClass.getAnnotations();
    // 是否有已经添加了MyService的标识
    boolean isAddMyServiceAnno = false;
    MyService myService = null;
    for (Annotation annotation : annotations) {
        // 判断所有注解里面有没有@MyService注解
        if (null != annotation && annotation instanceof MyService) {
            // 有的话就做好标记并暂存一个注解对象
            isAddMyServiceAnno = true;
            myService = (MyService) annotation;
            // 跳出循环
            break;
        }
    }
    
    // 判断当前实现是否有加MyService注解
    if (isAddMyServiceAnno && null != myService) {
        // 获取类的名称 作为默认的缓存key
        String objName = scanClass.getName();
        // 有的话获取指定的名称
        if (null != myService.value()) {
            // 如果指定的缓存名称,就用指定的
            objName = myService.value();
        }
    
        // 通过反射实例化对象
        // 这里并没有去管属性的注入
        Object o = scanClass.newInstance();
        // 缓存起来
        objs.put(objName, o);
    }
    

缓存

前面的步骤已经拿到label和对应实现类的关系;并且实现类已经实例化了;下一步要做的,就是将他们缓存起来

为了使这个MyContext对象能适用性更广,这里使用了一个嵌套Map去缓存

private static Map<String, Map<String, Object>> objCache = new HashMap<>();

第一层的Map:Key为接口的路径,Value的Map对应这个接口的所有实现的集合

第二层的Map;key为指定的名称,Value的Object为具体的实现

缓存的过程

  1. 扫描出接口所有的实现类
  2. 循环找出lebel与实现类的对应关系,并实例化对象
  3. 以label作为Key;实现类作为Value缓存在Map<String, Object>
  4. 循环找完所有实现之后,在以接口的路径作为key,Map<String, Object>作为value,进行缓存;

MyContext中getBean的具体代码

getBean的作用就是去查找缓存,没缓存就去扫描类并缓存;有缓存之后就直接将缓存中的类取出来并返回使用

public static <T> T getBean(String label, Class<T> clz) {
    if (null == label || null == clz) {
        return null;
    }

    String clzPackage = clz.getPackage().getName();
    // 首先在缓存中获取,看是否之前已经扫描过了
    Map<String, Object> stringObjectMap = objCache.get(clzPackage);
    if (null == stringObjectMap) {
        // 如果为空,说明之前没有扫描;扫描一遍
        stringObjectMap = scanClass(clz);
    }

    Object obj = null;
    // 判断是否扫描到实现类了
    if (null != stringObjectMap) {
        // 在所有的实现类中缓存集合中找到具体的实现
        obj = stringObjectMap.get(label);
    }

    // 返回之前,再次确认一下 
    // 是不是为空
    // 对象是不是实现了clz的接口
    if (null != obj && clz.isAssignableFrom(obj.getClass())) {
        return (T) obj;
    }

    // 这里根据情况 看是否要抛异常
    return null;
}

测试

@Test
public void getPrice3() {
    VipService vip2 = MyContext.getBean("vip2", VipService.class);
    log.info("vip2 price:{}", vip2.getPrice(10000));
    VipService vip1 = MyContext.getBean("vip1", VipService.class);
    log.info("vip1 price:{}", vip1.getPrice(10000));
}

到这里,不使用Spring,我们想要的效果已经达到了;

这只是一个基础的实现,并没有Spring提供的那么完善和健壮;这里更多是给了一个实现的思路,希望能帮到你;

总结

本文虽然说的是要去掉if-else;但是并没有任何说if-else不好的意思;if-else作为分支语法,是不可或缺的;但我更想说的是实际的开发过程中能熟练的使用一些小技巧,搭配点设计模式等,可以让代码不管是结构上、还是思路逻辑上,都会变的清晰优雅很多;要做到熟练这一点,需要平时不断的去思考,练习;不管你处于什么水平,把一段时间之前的代码拿出来读一读,总能发现一些可以改进的地方;如此往复的完善,量变终究能带来质变!

测试源码地址:https://gitee.com/pengfeilu/strategy-demo



标题:if-else泛滥成灾?试试【策略模式】优雅的给去掉吧!
作者:码霸霸
地址:https://blog.lupf.cn/articles/2021/05/19/1621388005047.html