init
Browse files- app.py +218 -0
 - process.py +48 -0
 - requirements.txt +10 -0
 
    	
        app.py
    ADDED
    
    | 
         @@ -0,0 +1,218 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import io
         
     | 
| 2 | 
         
            +
            import ssl
         
     | 
| 3 | 
         
            +
            import base64
         
     | 
| 4 | 
         
            +
            import torch
         
     | 
| 5 | 
         
            +
            import streamlit as st
         
     | 
| 6 | 
         
            +
            import urllib.request
         
     | 
| 7 | 
         
            +
            import psutil
         
     | 
| 8 | 
         
            +
            from PIL import Image
         
     | 
| 9 | 
         
            +
            from process import process
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            st.set_page_config("Ai抠图(RMBG 2.0)", layout="wide")
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            st.markdown(
         
     | 
| 14 | 
         
            +
                """<style>
         
     | 
| 15 | 
         
            +
            .stDeployButton {
         
     | 
| 16 | 
         
            +
                visibility: hidden;
         
     | 
| 17 | 
         
            +
            }
         
     | 
| 18 | 
         
            +
            .block-container {
         
     | 
| 19 | 
         
            +
                padding: 3rem 2rem 2rem 2rem;
         
     | 
| 20 | 
         
            +
            }
         
     | 
| 21 | 
         
            +
            .st-emotion-cache-1mi2ry5 {
         
     | 
| 22 | 
         
            +
                padding: 0rem 1rem;
         
     | 
| 23 | 
         
            +
            }
         
     | 
| 24 | 
         
            +
            .st-emotion-cache-1gwvy71 {
         
     | 
| 25 | 
         
            +
                padding: 1rem 1rem;
         
     | 
| 26 | 
         
            +
            }
         
     | 
| 27 | 
         
            +
            </style>""",
         
     | 
| 28 | 
         
            +
                unsafe_allow_html=True,
         
     | 
| 29 | 
         
            +
            )
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
            state = st.session_state
         
     | 
| 32 | 
         
            +
            if "image" not in state:
         
     | 
| 33 | 
         
            +
                state.image = ""
         
     | 
| 34 | 
         
            +
            if "image_nbg" not in state:
         
     | 
| 35 | 
         
            +
                state.image_nbg = ""
         
     | 
| 36 | 
         
            +
            if "mask" not in state:
         
     | 
| 37 | 
         
            +
                state.mask = ""
         
     | 
| 38 | 
         
            +
            if "filename" not in state:
         
     | 
| 39 | 
         
            +
                state.filename = ""
         
     | 
| 40 | 
         
            +
            if "image_stream" not in state:
         
     | 
| 41 | 
         
            +
                state.image_stream = None
         
     | 
| 42 | 
         
            +
            if "read_file_once" not in state:
         
     | 
| 43 | 
         
            +
                state.read_file_once = 0
         
     | 
| 44 | 
         
            +
             
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
            IMAGE_FORMATS = ("jpg", "png", "jpeg", "JPG", "PNG", "JPEG")
         
     | 
| 47 | 
         
            +
            DEVICE = "GPU" if torch.cuda.is_available() else "CPU"
         
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
            @st.dialog("上传图片")
         
     | 
| 51 | 
         
            +
            def upload_image(input_image_ph, output_image_ph):
         
     | 
| 52 | 
         
            +
             
     | 
| 53 | 
         
            +
                # 网络图片
         
     | 
| 54 | 
         
            +
                st.markdown("**图片链接**", help="填写网络图片地址")
         
     | 
| 55 | 
         
            +
                cls = st.columns([0.8, 0.2])
         
     | 
| 56 | 
         
            +
                url = cls[0].text_input(
         
     | 
| 57 | 
         
            +
                    "xxx", placeholder="url/base64...", label_visibility="collapsed"
         
     | 
| 58 | 
         
            +
                )
         
     | 
| 59 | 
         
            +
                if cls[1].button(
         
     | 
| 60 | 
         
            +
                    "读取",
         
     | 
| 61 | 
         
            +
                    use_container_width=True,
         
     | 
| 62 | 
         
            +
                    # disabled=not url or not url.startswith(("https://", "data:image/")),
         
     | 
| 63 | 
         
            +
                ):
         
     | 
| 64 | 
         
            +
                    try:
         
     | 
| 65 | 
         
            +
                        if url.startswith(("https://", "http://")):
         
     | 
| 66 | 
         
            +
                            content = ssl._create_unverified_context()
         
     | 
| 67 | 
         
            +
                            with urllib.request.urlopen(url, context=content) as response:
         
     | 
| 68 | 
         
            +
                                image_data = response.read()
         
     | 
| 69 | 
         
            +
                                state.image_stream = io.BytesIO(image_data)
         
     | 
| 70 | 
         
            +
                                name = "image." + url.rsplit(".", 1)[-1]
         
     | 
| 71 | 
         
            +
                        elif url.startswith("data:image/"):
         
     | 
| 72 | 
         
            +
                            pfix, base64_data = url.split(",", 1)
         
     | 
| 73 | 
         
            +
                            state.image_stream = io.BytesIO(base64.b64decode(base64_data))
         
     | 
| 74 | 
         
            +
                            name = "image." + pfix[11:-7]
         
     | 
| 75 | 
         
            +
                        else:
         
     | 
| 76 | 
         
            +
                            st.warning(":red[请输入有效的图片链接]")
         
     | 
| 77 | 
         
            +
                    except Exception as e:
         
     | 
| 78 | 
         
            +
                        st.warning(f":red[**读取图片失败,请保存到本地后上传**]")
         
     | 
| 79 | 
         
            +
                        return
         
     | 
| 80 | 
         
            +
             
     | 
| 81 | 
         
            +
                # 本地图片
         
     | 
| 82 | 
         
            +
                def _cb():
         
     | 
| 83 | 
         
            +
                    state.read_file_once = 1
         
     | 
| 84 | 
         
            +
             
     | 
| 85 | 
         
            +
                st.markdown("**上传图片**")
         
     | 
| 86 | 
         
            +
                file = st.file_uploader(
         
     | 
| 87 | 
         
            +
                    "xxx",
         
     | 
| 88 | 
         
            +
                    accept_multiple_files=False,
         
     | 
| 89 | 
         
            +
                    type=IMAGE_FORMATS,
         
     | 
| 90 | 
         
            +
                    label_visibility="collapsed",
         
     | 
| 91 | 
         
            +
                    on_change=_cb,
         
     | 
| 92 | 
         
            +
                    key="upload_key",
         
     | 
| 93 | 
         
            +
                )
         
     | 
| 94 | 
         
            +
                if state.read_file_once and file:
         
     | 
| 95 | 
         
            +
                    state.image_stream = io.BytesIO(file.getvalue())
         
     | 
| 96 | 
         
            +
                    name = file.name
         
     | 
| 97 | 
         
            +
                    state.read_file_once = 0
         
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
                if state.image_stream is not None:
         
     | 
| 100 | 
         
            +
                    try:
         
     | 
| 101 | 
         
            +
                        image = Image.open(state.image_stream)
         
     | 
| 102 | 
         
            +
                        state.image = image
         
     | 
| 103 | 
         
            +
                        state.mask = ""
         
     | 
| 104 | 
         
            +
                        state.image_nbg = ""
         
     | 
| 105 | 
         
            +
                        state.filename = name
         
     | 
| 106 | 
         
            +
             
     | 
| 107 | 
         
            +
                        input_image_ph.image(image)
         
     | 
| 108 | 
         
            +
                        output_image_ph.empty()
         
     | 
| 109 | 
         
            +
                        st.success(":rainbow[**上传成功**]")
         
     | 
| 110 | 
         
            +
                    except Exception as e:
         
     | 
| 111 | 
         
            +
                        st.warning(f":red[处理图片出错 >> {e}]")
         
     | 
| 112 | 
         
            +
             
     | 
| 113 | 
         
            +
                    state.image_stream = None
         
     | 
| 114 | 
         
            +
             
     | 
| 115 | 
         
            +
             
     | 
| 116 | 
         
            +
            @st.dialog("下载图片")
         
     | 
| 117 | 
         
            +
            def download_image():
         
     | 
| 118 | 
         
            +
                if not state.mask or not state.image_nbg:
         
     | 
| 119 | 
         
            +
                    st.warning("请上传图片")
         
     | 
| 120 | 
         
            +
                else:
         
     | 
| 121 | 
         
            +
                    with st.spinner("正在处理中..."):
         
     | 
| 122 | 
         
            +
                        buffer1 = io.BytesIO()
         
     | 
| 123 | 
         
            +
                        state.mask.save(buffer1, format="PNG")
         
     | 
| 124 | 
         
            +
                        buffer2 = io.BytesIO()
         
     | 
| 125 | 
         
            +
                        state.image_nbg.save(buffer2, format="PNG")
         
     | 
| 126 | 
         
            +
             
     | 
| 127 | 
         
            +
                    name = state.filename.rsplit(".", 1)[0] + "-mask.png"
         
     | 
| 128 | 
         
            +
                    st.download_button(
         
     | 
| 129 | 
         
            +
                        "下载掩码图片",
         
     | 
| 130 | 
         
            +
                        data=buffer1.getvalue(),
         
     | 
| 131 | 
         
            +
                        file_name=name,
         
     | 
| 132 | 
         
            +
                        use_container_width=True,
         
     | 
| 133 | 
         
            +
                        disabled=not state.mask,
         
     | 
| 134 | 
         
            +
                    )
         
     | 
| 135 | 
         
            +
             
     | 
| 136 | 
         
            +
                    name = state.filename.rsplit(".", 1)[0] + "-no-bg.png"
         
     | 
| 137 | 
         
            +
                    st.download_button(
         
     | 
| 138 | 
         
            +
                        "下载前景图片",
         
     | 
| 139 | 
         
            +
                        data=buffer2.getvalue(),
         
     | 
| 140 | 
         
            +
                        file_name=name,
         
     | 
| 141 | 
         
            +
                        use_container_width=True,
         
     | 
| 142 | 
         
            +
                        disabled=not state.image_nbg,
         
     | 
| 143 | 
         
            +
                    )
         
     | 
| 144 | 
         
            +
             
     | 
| 145 | 
         
            +
             
     | 
| 146 | 
         
            +
            def main():
         
     | 
| 147 | 
         
            +
                st.markdown(
         
     | 
| 148 | 
         
            +
                    '<h1 style="text-align: center; color: white; background: #4b4bff; font-size: 26px; border-radius: .5rem; margin-bottom: 15px;">Ai抠图 (RMBG 2.0)</h1>',
         
     | 
| 149 | 
         
            +
                    unsafe_allow_html=True,
         
     | 
| 150 | 
         
            +
                )
         
     | 
| 151 | 
         
            +
                body_cls = st.columns(2)
         
     | 
| 152 | 
         
            +
             
     | 
| 153 | 
         
            +
                body_cls[0].markdown(
         
     | 
| 154 | 
         
            +
                    "<h6 style='text-align: center;'>原图像</h6>",
         
     | 
| 155 | 
         
            +
                    unsafe_allow_html=True,
         
     | 
| 156 | 
         
            +
                )
         
     | 
| 157 | 
         
            +
                body_cls[1].markdown(
         
     | 
| 158 | 
         
            +
                    "<h6 style='text-align: center;'>处理后</h6>",
         
     | 
| 159 | 
         
            +
                    unsafe_allow_html=True,
         
     | 
| 160 | 
         
            +
                )
         
     | 
| 161 | 
         
            +
             
     | 
| 162 | 
         
            +
                HEIGHT = 400
         
     | 
| 163 | 
         
            +
                input_container = body_cls[0].container(height=HEIGHT)
         
     | 
| 164 | 
         
            +
                output_container = body_cls[1].container(height=HEIGHT)
         
     | 
| 165 | 
         
            +
                input_image_ph = input_container.empty()
         
     | 
| 166 | 
         
            +
                output_image_ph = output_container.empty()
         
     | 
| 167 | 
         
            +
             
     | 
| 168 | 
         
            +
                # show image
         
     | 
| 169 | 
         
            +
                if state.image:
         
     | 
| 170 | 
         
            +
                    input_image_ph.image(state.image)
         
     | 
| 171 | 
         
            +
                if state.image_nbg:
         
     | 
| 172 | 
         
            +
                    output_image_ph.image(state.image_nbg)
         
     | 
| 173 | 
         
            +
             
     | 
| 174 | 
         
            +
                btm_cls = st.columns(3)
         
     | 
| 175 | 
         
            +
                submit_btn = btm_cls[0].button(
         
     | 
| 176 | 
         
            +
                    ":orange[:material/Cloud_Upload: **上传图片**]", use_container_width=True
         
     | 
| 177 | 
         
            +
                )
         
     | 
| 178 | 
         
            +
                process_ph = btm_cls[1].empty()
         
     | 
| 179 | 
         
            +
                process_btn = process_ph.button(
         
     | 
| 180 | 
         
            +
                    ":rainbow[:material/Hourglass_Empty: **进行处理**]", use_container_width=True
         
     | 
| 181 | 
         
            +
                )
         
     | 
| 182 | 
         
            +
                download_btn = btm_cls[2].button(
         
     | 
| 183 | 
         
            +
                    ":green[:material/Download_2: **下载图片**]",
         
     | 
| 184 | 
         
            +
                    use_container_width=True,
         
     | 
| 185 | 
         
            +
                    disabled=not state.mask,
         
     | 
| 186 | 
         
            +
                )
         
     | 
| 187 | 
         
            +
             
     | 
| 188 | 
         
            +
                if DEVICE == "CPU":
         
     | 
| 189 | 
         
            +
                    cpu_percent = psutil.cpu_percent(interval=0.5)
         
     | 
| 190 | 
         
            +
                    cpustates = (
         
     | 
| 191 | 
         
            +
                        f"🌟CPU运行会比较慢,请耐心等待一下~🫡(CPU利用率:{cpu_percent:.3f}%)"
         
     | 
| 192 | 
         
            +
                    )
         
     | 
| 193 | 
         
            +
                    st.caption(
         
     | 
| 194 | 
         
            +
                        f'<p style="text-align: center;">{cpustates}</p>',
         
     | 
| 195 | 
         
            +
                        unsafe_allow_html=True,
         
     | 
| 196 | 
         
            +
                    )
         
     | 
| 197 | 
         
            +
             
     | 
| 198 | 
         
            +
                if submit_btn:
         
     | 
| 199 | 
         
            +
                    upload_image(input_image_ph, output_image_ph)
         
     | 
| 200 | 
         
            +
             
     | 
| 201 | 
         
            +
                if process_btn:
         
     | 
| 202 | 
         
            +
                    if state.image:
         
     | 
| 203 | 
         
            +
                        with output_image_ph.container(), st.spinner(f"正在处理中({DEVICE})..."):
         
     | 
| 204 | 
         
            +
                            mask, image_nbg = process(state.image)
         
     | 
| 205 | 
         
            +
             
     | 
| 206 | 
         
            +
                        state.image_nbg = image_nbg
         
     | 
| 207 | 
         
            +
                        state.mask = mask
         
     | 
| 208 | 
         
            +
                        output_container.image(image_nbg)
         
     | 
| 209 | 
         
            +
                        st.rerun()
         
     | 
| 210 | 
         
            +
                    else:
         
     | 
| 211 | 
         
            +
                        st.toast("请上传图片")
         
     | 
| 212 | 
         
            +
             
     | 
| 213 | 
         
            +
                if download_btn:
         
     | 
| 214 | 
         
            +
                    download_image()
         
     | 
| 215 | 
         
            +
             
     | 
| 216 | 
         
            +
             
     | 
| 217 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 218 | 
         
            +
                main()
         
     | 
    	
        process.py
    ADDED
    
    | 
         @@ -0,0 +1,48 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import os
         
     | 
| 2 | 
         
            +
            import torch
         
     | 
| 3 | 
         
            +
            import streamlit as st
         
     | 
| 4 | 
         
            +
            from PIL import Image
         
     | 
| 5 | 
         
            +
            from torchvision import transforms
         
     | 
| 6 | 
         
            +
            from transformers import AutoModelForImageSegmentation
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            @st.cache_resource
         
     | 
| 10 | 
         
            +
            def load_model(model_id_or_path="briaai/RMBG-2.0", precision=0, device="cuda"):
         
     | 
| 11 | 
         
            +
                model = AutoModelForImageSegmentation.from_pretrained(
         
     | 
| 12 | 
         
            +
                    model_id_or_path, trust_remote_code=True
         
     | 
| 13 | 
         
            +
                )
         
     | 
| 14 | 
         
            +
                torch.set_float32_matmul_precision(["high", "highest"][precision])
         
     | 
| 15 | 
         
            +
                model.to(device)
         
     | 
| 16 | 
         
            +
                _ = model.eval()
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
                # Data settings
         
     | 
| 19 | 
         
            +
                image_size = (1024, 1024)
         
     | 
| 20 | 
         
            +
                transform_image = transforms.Compose(
         
     | 
| 21 | 
         
            +
                    [
         
     | 
| 22 | 
         
            +
                        transforms.Resize(image_size),
         
     | 
| 23 | 
         
            +
                        transforms.ToTensor(),
         
     | 
| 24 | 
         
            +
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
         
     | 
| 25 | 
         
            +
                    ]
         
     | 
| 26 | 
         
            +
                )
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
                return model, transform_image
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
            def process(image: Image.Image) -> Image.Image:
         
     | 
| 32 | 
         
            +
                if "RMBG-2.0" not in os.listdir("."):
         
     | 
| 33 | 
         
            +
                    os.system(
         
     | 
| 34 | 
         
            +
                        "modelscope download --model AI-ModelScope/RMBG-2.0 --local_dir RMBG-2.0 --exclude *.onnx *.bin"
         
     | 
| 35 | 
         
            +
                    )
         
     | 
| 36 | 
         
            +
                device = "cuda" if torch.cuda.is_available() else "cpu"
         
     | 
| 37 | 
         
            +
                precision = 0
         
     | 
| 38 | 
         
            +
                model, transform = load_model("RMBG-2.0", precision=precision, device=device)
         
     | 
| 39 | 
         
            +
                image = image.copy()
         
     | 
| 40 | 
         
            +
                input_images = transform(image).unsqueeze(0).to(device)
         
     | 
| 41 | 
         
            +
                with torch.no_grad():
         
     | 
| 42 | 
         
            +
                    preds = model(input_images)[-1].sigmoid().cpu()
         
     | 
| 43 | 
         
            +
                    pred = preds[0].squeeze()
         
     | 
| 44 | 
         
            +
                pred_pil = transforms.ToPILImage()(pred)
         
     | 
| 45 | 
         
            +
                mask = pred_pil.resize(image.size)
         
     | 
| 46 | 
         
            +
                image.putalpha(mask)
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
                return mask, image
         
     | 
    	
        requirements.txt
    ADDED
    
    | 
         @@ -0,0 +1,10 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            torch
         
     | 
| 2 | 
         
            +
            torchvision
         
     | 
| 3 | 
         
            +
            pillow
         
     | 
| 4 | 
         
            +
            kornia
         
     | 
| 5 | 
         
            +
            transformers
         
     | 
| 6 | 
         
            +
            streamlit
         
     | 
| 7 | 
         
            +
            huggingface
         
     | 
| 8 | 
         
            +
            timm
         
     | 
| 9 | 
         
            +
            modelscope
         
     | 
| 10 | 
         
            +
            psutil
         
     |