Spaces:
Paused
Paused
| import json | |
| import time | |
| import uuid | |
| from typing import Optional | |
| from open_webui.internal.db import Base, get_db | |
| from open_webui.models.tags import TagModel, Tag, Tags | |
| from pydantic import BaseModel, ConfigDict | |
| from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON | |
| from sqlalchemy import or_, func, select, and_, text | |
| from sqlalchemy.sql import exists | |
| #################### | |
| # Message DB Schema | |
| #################### | |
| class MessageReaction(Base): | |
| __tablename__ = "message_reaction" | |
| id = Column(Text, primary_key=True) | |
| user_id = Column(Text) | |
| message_id = Column(Text) | |
| name = Column(Text) | |
| created_at = Column(BigInteger) | |
| class MessageReactionModel(BaseModel): | |
| model_config = ConfigDict(from_attributes=True) | |
| id: str | |
| user_id: str | |
| message_id: str | |
| name: str | |
| created_at: int # timestamp in epoch | |
| class Message(Base): | |
| __tablename__ = "message" | |
| id = Column(Text, primary_key=True) | |
| user_id = Column(Text) | |
| channel_id = Column(Text, nullable=True) | |
| parent_id = Column(Text, nullable=True) | |
| content = Column(Text) | |
| data = Column(JSON, nullable=True) | |
| meta = Column(JSON, nullable=True) | |
| created_at = Column(BigInteger) # time_ns | |
| updated_at = Column(BigInteger) # time_ns | |
| class MessageModel(BaseModel): | |
| model_config = ConfigDict(from_attributes=True) | |
| id: str | |
| user_id: str | |
| channel_id: Optional[str] = None | |
| parent_id: Optional[str] = None | |
| content: str | |
| data: Optional[dict] = None | |
| meta: Optional[dict] = None | |
| created_at: int # timestamp in epoch | |
| updated_at: int # timestamp in epoch | |
| #################### | |
| # Forms | |
| #################### | |
| class MessageForm(BaseModel): | |
| content: str | |
| parent_id: Optional[str] = None | |
| data: Optional[dict] = None | |
| meta: Optional[dict] = None | |
| class Reactions(BaseModel): | |
| name: str | |
| user_ids: list[str] | |
| count: int | |
| class MessageResponse(MessageModel): | |
| latest_reply_at: Optional[int] | |
| reply_count: int | |
| reactions: list[Reactions] | |
| class MessageTable: | |
| def insert_new_message( | |
| self, form_data: MessageForm, channel_id: str, user_id: str | |
| ) -> Optional[MessageModel]: | |
| with get_db() as db: | |
| id = str(uuid.uuid4()) | |
| ts = int(time.time_ns()) | |
| message = MessageModel( | |
| **{ | |
| "id": id, | |
| "user_id": user_id, | |
| "channel_id": channel_id, | |
| "parent_id": form_data.parent_id, | |
| "content": form_data.content, | |
| "data": form_data.data, | |
| "meta": form_data.meta, | |
| "created_at": ts, | |
| "updated_at": ts, | |
| } | |
| ) | |
| result = Message(**message.model_dump()) | |
| db.add(result) | |
| db.commit() | |
| db.refresh(result) | |
| return MessageModel.model_validate(result) if result else None | |
| def get_message_by_id(self, id: str) -> Optional[MessageResponse]: | |
| with get_db() as db: | |
| message = db.get(Message, id) | |
| if not message: | |
| return None | |
| reactions = self.get_reactions_by_message_id(id) | |
| replies = self.get_replies_by_message_id(id) | |
| return MessageResponse( | |
| **{ | |
| **MessageModel.model_validate(message).model_dump(), | |
| "latest_reply_at": replies[0].created_at if replies else None, | |
| "reply_count": len(replies), | |
| "reactions": reactions, | |
| } | |
| ) | |
| def get_replies_by_message_id(self, id: str) -> list[MessageModel]: | |
| with get_db() as db: | |
| all_messages = ( | |
| db.query(Message) | |
| .filter_by(parent_id=id) | |
| .order_by(Message.created_at.desc()) | |
| .all() | |
| ) | |
| return [MessageModel.model_validate(message) for message in all_messages] | |
| def get_reply_user_ids_by_message_id(self, id: str) -> list[str]: | |
| with get_db() as db: | |
| return [ | |
| message.user_id | |
| for message in db.query(Message).filter_by(parent_id=id).all() | |
| ] | |
| def get_messages_by_channel_id( | |
| self, channel_id: str, skip: int = 0, limit: int = 50 | |
| ) -> list[MessageModel]: | |
| with get_db() as db: | |
| all_messages = ( | |
| db.query(Message) | |
| .filter_by(channel_id=channel_id, parent_id=None) | |
| .order_by(Message.created_at.desc()) | |
| .offset(skip) | |
| .limit(limit) | |
| .all() | |
| ) | |
| return [MessageModel.model_validate(message) for message in all_messages] | |
| def get_messages_by_parent_id( | |
| self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50 | |
| ) -> list[MessageModel]: | |
| with get_db() as db: | |
| message = db.get(Message, parent_id) | |
| if not message: | |
| return [] | |
| all_messages = ( | |
| db.query(Message) | |
| .filter_by(channel_id=channel_id, parent_id=parent_id) | |
| .order_by(Message.created_at.desc()) | |
| .offset(skip) | |
| .limit(limit) | |
| .all() | |
| ) | |
| # If length of all_messages is less than limit, then add the parent message | |
| if len(all_messages) < limit: | |
| all_messages.append(message) | |
| return [MessageModel.model_validate(message) for message in all_messages] | |
| def update_message_by_id( | |
| self, id: str, form_data: MessageForm | |
| ) -> Optional[MessageModel]: | |
| with get_db() as db: | |
| message = db.get(Message, id) | |
| message.content = form_data.content | |
| message.data = form_data.data | |
| message.meta = form_data.meta | |
| message.updated_at = int(time.time_ns()) | |
| db.commit() | |
| db.refresh(message) | |
| return MessageModel.model_validate(message) if message else None | |
| def add_reaction_to_message( | |
| self, id: str, user_id: str, name: str | |
| ) -> Optional[MessageReactionModel]: | |
| with get_db() as db: | |
| reaction_id = str(uuid.uuid4()) | |
| reaction = MessageReactionModel( | |
| id=reaction_id, | |
| user_id=user_id, | |
| message_id=id, | |
| name=name, | |
| created_at=int(time.time_ns()), | |
| ) | |
| result = MessageReaction(**reaction.model_dump()) | |
| db.add(result) | |
| db.commit() | |
| db.refresh(result) | |
| return MessageReactionModel.model_validate(result) if result else None | |
| def get_reactions_by_message_id(self, id: str) -> list[Reactions]: | |
| with get_db() as db: | |
| all_reactions = db.query(MessageReaction).filter_by(message_id=id).all() | |
| reactions = {} | |
| for reaction in all_reactions: | |
| if reaction.name not in reactions: | |
| reactions[reaction.name] = { | |
| "name": reaction.name, | |
| "user_ids": [], | |
| "count": 0, | |
| } | |
| reactions[reaction.name]["user_ids"].append(reaction.user_id) | |
| reactions[reaction.name]["count"] += 1 | |
| return [Reactions(**reaction) for reaction in reactions.values()] | |
| def remove_reaction_by_id_and_user_id_and_name( | |
| self, id: str, user_id: str, name: str | |
| ) -> bool: | |
| with get_db() as db: | |
| db.query(MessageReaction).filter_by( | |
| message_id=id, user_id=user_id, name=name | |
| ).delete() | |
| db.commit() | |
| return True | |
| def delete_reactions_by_id(self, id: str) -> bool: | |
| with get_db() as db: | |
| db.query(MessageReaction).filter_by(message_id=id).delete() | |
| db.commit() | |
| return True | |
| def delete_replies_by_id(self, id: str) -> bool: | |
| with get_db() as db: | |
| db.query(Message).filter_by(parent_id=id).delete() | |
| db.commit() | |
| return True | |
| def delete_message_by_id(self, id: str) -> bool: | |
| with get_db() as db: | |
| db.query(Message).filter_by(id=id).delete() | |
| # Delete all reactions to this message | |
| db.query(MessageReaction).filter_by(message_id=id).delete() | |
| db.commit() | |
| return True | |
| Messages = MessageTable() | |