vdkch/main.py

313 lines
12 KiB
Python
Raw Normal View History

2025-02-10 15:53:52 +03:00
# main.py
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import List, Optional
import uuid
import os
from datetime import datetime
import requests
from qdrant_client import QdrantClient
from qdrant_client.http import models
from fastapi.middleware.cors import CORSMiddleware
# Database setup
from sqlalchemy import create_engine, Column, String, DateTime, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
# Image processing
from PIL import Image
import io
import torch
from transformers import CLIPProcessor, CLIPModel
import logging
from qdrant_client.http.models import VectorParams, Distance
# Настройка базового логирования (например, вывод в консоль)
logging.basicConfig(
level=logging.INFO, # Можно изменить уровень, например, на DEBUG
format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
)
logger = logging.getLogger(__name__)
# Задайте имя коллекции
COLLECTION_NAME = "posts"
# Определите размер вектора. Этот размер должен соответствовать длине объединённого эмбеддинга текста и изображения.
VECTOR_SIZE = 1280 # Пример: поменяйте на актуальное значение для вашего случая
# Configuration
DATABASE_URL = "sqlite:///./imageboard.db"
QDRANT_URL = "http://localhost:6333"
OLLAMA_URL = "http://localhost:11434"
EMBEDDING_MODEL = "nomic-embed-text" # Локальная модель через Ollama
IMAGE_MODEL = "openai/clip-vit-base-patch32" # Локальная CLIP модель
IMAGE_SIZE = (224, 224)
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)
# Initialize components
Base = declarative_base()
engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
#Base.metadata.drop_all(bind=engine) # Удаляет все таблицы
#Base.metadata.create_all(bind=engine) # Создаёт таблицы заново
# Инициализация CLIP для изображений
clip_model = CLIPModel.from_pretrained(IMAGE_MODEL)
clip_processor = CLIPProcessor.from_pretrained(IMAGE_MODEL)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_model = clip_model.to(device)
# Qdrant клиент
qdrant_client = QdrantClient(QDRANT_URL)
def ensure_collection_exists():
try:
# Попытка получить коллекцию. Если коллекция не существует, Qdrant выбросит исключение.
qdrant_client.get_collection(collection_name=COLLECTION_NAME)
logger.info("Коллекция '%s' существует.", COLLECTION_NAME)
except Exception as e:
logger.info("Коллекция '%s' не найдена. Создаём коллекцию...", COLLECTION_NAME)
qdrant_client.create_collection(
collection_name=COLLECTION_NAME,
vectors_config=VectorParams(
size=VECTOR_SIZE,
distance=Distance.COSINE # Или другой подходящий тип расстояния
)
)
logger.info("Коллекция '%s' создана.", COLLECTION_NAME)
# Вызываем функцию при инициализации приложения, например, в начале main.py
ensure_collection_exists()
app = FastAPI()
app.mount("/uploads", StaticFiles(directory=UPLOAD_DIR), name=UPLOAD_DIR)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Разрешить все источники
allow_credentials=True,
allow_methods=["*"], # Разрешить все методы
allow_headers=["*"], # Разрешить все заголовки
)
# Database models
class Post(Base):
__tablename__ = "posts"
id = Column(String, primary_key=True, index=True)
text = Column(Text, nullable=True)
image = Column(String, nullable=True)
created_at = Column(DateTime)
Base.metadata.create_all(bind=engine)
# Pydantic model для ответа
class PostResponse(BaseModel):
id: str
text: Optional[str] = None
image: Optional[str] = None
created_at: datetime
vector: Optional[List[float]] = None
class Config:
orm_mode = True
# Pydantic модель для запроса вектора
class VectorQuery(BaseModel):
vector: List[float]
# Utility functions
def generate_text_embedding(text: str) -> List[float]:
response = requests.post(
f"{OLLAMA_URL}/api/embeddings",
json={"model": EMBEDDING_MODEL, "prompt": text}
)
if response.status_code != 200:
raise HTTPException(status_code=500, detail="Embedding generation failed")
return response.json()["embedding"]
def generate_image_embedding(image_bytes: bytes) -> List[float]:
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
inputs = clip_processor(
images=image,
return_tensors="pt",
padding=True
).to(device)
with torch.no_grad():
features = clip_model.get_image_features(**inputs)
return features.cpu().numpy().tolist()[0]
def process_image(image_bytes: bytes) -> bytes:
img = Image.open(io.BytesIO(image_bytes))
img = img.convert("RGB")
img = img.resize(IMAGE_SIZE)
buffer = io.BytesIO()
img.save(buffer, format="JPEG")
return buffer.getvalue()
2025-02-12 00:57:29 +03:00
# API endpoints
2025-02-10 15:53:52 +03:00
# API endpoints
@app.post("/posts/", response_model=PostResponse)
async def create_post(
text: Optional[str] = Form(None),
image: Optional[UploadFile] = File(None)
):
db = SessionLocal()
try:
post_id = str(uuid.uuid4())
image_path = None
2025-02-12 00:57:29 +03:00
thumbnail_path = None # Placeholder for thumbnail
2025-02-10 15:53:52 +03:00
embeddings = []
if text:
logger.info("Генерация эмбеддинга для текста")
text_embedding = generate_text_embedding(text)
embeddings.extend(text_embedding)
if image:
logger.info("Обработка изображения")
image_bytes = await image.read()
2025-02-12 00:57:29 +03:00
# Save original image
2025-02-10 15:53:52 +03:00
image_path = f"{UPLOAD_DIR}/{post_id}.jpg"
with open(image_path, "wb") as f:
2025-02-12 00:57:29 +03:00
f.write(image_bytes)
# Create processed image for embeddings
processed_image = process_image(image_bytes) # Assume this resizes/image processing
2025-02-10 15:53:52 +03:00
2025-02-12 00:57:29 +03:00
# Generate thumbnail as placeholder (example implementation)
#thumbnail = generate_thumbnail(image_bytes) # Implement your thumbnail generation
#thumbnail_path = f"{UPLOAD_DIR}/{post_id}_thumbnail.jpg"
#with open(thumbnail_path, "wb") as f:
# f.write(thumbnail)
# Generate embeddings from processed image
2025-02-10 15:53:52 +03:00
image_embedding = generate_image_embedding(processed_image)
embeddings.extend(image_embedding)
logger.info("Сохранение данных в Qdrant")
qdrant_client.upsert(
collection_name="posts",
points=[models.PointStruct(
id=post_id,
vector=embeddings,
payload={"post_id": post_id}
)]
)
logger.info("Сохранение поста в базу данных")
db_post = Post(
id=post_id,
text=text,
image=image_path,
2025-02-12 00:57:29 +03:00
#thumbnail=thumbnail_path, Add thumbnail field to your Post model
2025-02-10 15:53:52 +03:00
created_at=datetime.now()
)
db.add(db_post)
db.commit()
db.refresh(db_post)
response = PostResponse(
id=db_post.id,
text=db_post.text,
image=db_post.image,
2025-02-12 00:57:29 +03:00
#thumbnail=db_post.thumbnail, Update response model to include thumbnail
2025-02-10 15:53:52 +03:00
created_at=db_post.created_at,
vector=embeddings
)
logger.info("Пост успешно создан: %s", response)
return response
except Exception as e:
db.rollback()
logger.exception("Ошибка при создании поста")
raise HTTPException(status_code=500, detail=str(e))
finally:
db.close()
2025-02-12 00:57:29 +03:00
2025-02-10 15:53:52 +03:00
@app.get("/search/")
async def search_posts(
text: Optional[str] = None,
image: Optional[UploadFile] = File(None)
):
try:
query_embedding = []
if text:
logger.info("Генерация эмбеддинга для текста (поиск)")
text_embedding = generate_text_embedding(text)
query_embedding.extend(text_embedding)
if image:
logger.info("Генерация эмбеддинга для изображения (поиск)")
image_bytes = await image.read()
processed_image = process_image(image_bytes)
image_embedding = generate_image_embedding(processed_image)
query_embedding.extend(image_embedding)
logger.info("Выполнение поиска в Qdrant")
search_results = qdrant_client.search(
collection_name="posts",
query_vector=query_embedding,
limit=10
)
logger.info("Поиск завершён. Найдено результатов: %d", len(search_results))
return [result.payload for result in search_results]
except Exception as e:
logger.exception("Ошибка при поиске постов")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/posts/", response_model=List[PostResponse])
async def get_all_posts():
db = SessionLocal()
try:
posts = db.query(Post).all()
return posts
finally:
db.close()
# Новый endpoint: получение "древа" постов по вектору пользователя
@app.post("/posts/tree", response_model=List[PostResponse])
async def get_posts_tree(query: VectorQuery):
# Выполняем поиск в Qdrant с большим лимитом, чтобы получить все посты, отсортированные по сходству
search_results = qdrant_client.search(
collection_name="posts",
query_vector=query.vector,
limit=10000 # Задайте лимит в зависимости от ожидаемого числа постов
)
print("search_results")
print(search_results)
# Извлекаем список ID постов в том порядке, в котором Qdrant вернул результаты (от ближайших к дальним)
post_ids = [result.payload.get("post_id") for result in search_results]
print("post_ids")
print(post_ids)
db = SessionLocal()
try:
# Получаем все посты из БД по списку ID
posts = db.query(Post).filter(Post.id.in_(post_ids)).all()
print("posts")
print(posts)
# Создаём словарь для сохранения соответствия post_id -> post
posts_dict = {post.id: post for post in posts}
print("posts_dict")
print(posts_dict)
# Восстанавливаем порядок, используя список post_ids
ordered_posts = [posts_dict[pid] for pid in post_ids if pid in posts_dict]
print("ordered_posts")
print(ordered_posts)
return ordered_posts
finally:
db.close()
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)